| @@ -181,12 +181,12 @@ func runServ(c *cli.Context) { | |||||
| if requestedMode == models.ACCESS_MODE_WRITE || repo.IsPrivate { | if requestedMode == models.ACCESS_MODE_WRITE || repo.IsPrivate { | ||||
| keys := strings.Split(c.Args()[0], "-") | keys := strings.Split(c.Args()[0], "-") | ||||
| if len(keys) != 2 { | if len(keys) != 2 { | ||||
| fail("Key ID format error", "Invalid key ID: %s", c.Args()[0]) | |||||
| fail("Key ID format error", "Invalid key argument: %s", c.Args()[0]) | |||||
| } | } | ||||
| key, err := models.GetPublicKeyByID(com.StrTo(keys[1]).MustInt64()) | key, err := models.GetPublicKeyByID(com.StrTo(keys[1]).MustInt64()) | ||||
| if err != nil { | if err != nil { | ||||
| fail("Key ID format error", "Invalid key ID[%s]: %v", c.Args()[0], err) | |||||
| fail("Invalid key ID", "Invalid key ID[%s]: %v", c.Args()[0], err) | |||||
| } | } | ||||
| keyID = key.ID | keyID = key.ID | ||||
| @@ -48,6 +48,8 @@ HTTP_ADDR = | |||||
| HTTP_PORT = 3000 | HTTP_PORT = 3000 | ||||
| ; Disable SSH feature when not available | ; Disable SSH feature when not available | ||||
| DISABLE_SSH = false | DISABLE_SSH = false | ||||
| ; Whether use builtin SSH server or not. | |||||
| START_SSH_SERVER = false | |||||
| SSH_PORT = 22 | SSH_PORT = 22 | ||||
| ; Disable CDN even in "prod" mode | ; Disable CDN even in "prod" mode | ||||
| OFFLINE_MODE = false | OFFLINE_MODE = false | ||||
| @@ -13,7 +13,6 @@ import ( | |||||
| "io" | "io" | ||||
| "io/ioutil" | "io/ioutil" | ||||
| "os" | "os" | ||||
| "os/exec" | |||||
| "path" | "path" | ||||
| "path/filepath" | "path/filepath" | ||||
| "strings" | "strings" | ||||
| @@ -38,20 +37,7 @@ var ( | |||||
| ) | ) | ||||
| var sshOpLocker = sync.Mutex{} | var sshOpLocker = sync.Mutex{} | ||||
| var ( | |||||
| SSHPath string // SSH directory. | |||||
| appPath string // Execution(binary) path. | |||||
| ) | |||||
| // exePath returns the executable path. | |||||
| func exePath() (string, error) { | |||||
| file, err := exec.LookPath(os.Args[0]) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return filepath.Abs(file) | |||||
| } | |||||
| var SSHPath string // SSH directory. | |||||
| // homeDir returns the home directory of current user. | // homeDir returns the home directory of current user. | ||||
| func homeDir() string { | func homeDir() string { | ||||
| @@ -63,16 +49,9 @@ func homeDir() string { | |||||
| } | } | ||||
| func init() { | func init() { | ||||
| var err error | |||||
| if appPath, err = exePath(); err != nil { | |||||
| log.Fatal(4, "fail to get app path: %v\n", err) | |||||
| } | |||||
| appPath = strings.Replace(appPath, "\\", "/", -1) | |||||
| // Determine and create .ssh path. | // Determine and create .ssh path. | ||||
| SSHPath = filepath.Join(homeDir(), ".ssh") | SSHPath = filepath.Join(homeDir(), ".ssh") | ||||
| if err = os.MkdirAll(SSHPath, 0700); err != nil { | |||||
| if err := os.MkdirAll(SSHPath, 0700); err != nil { | |||||
| log.Fatal(4, "fail to create '%s': %v", SSHPath, err) | log.Fatal(4, "fail to create '%s': %v", SSHPath, err) | ||||
| } | } | ||||
| } | } | ||||
| @@ -114,7 +93,7 @@ func (k *PublicKey) OmitEmail() string { | |||||
| // GetAuthorizedString generates and returns formatted public key string for authorized_keys file. | // GetAuthorizedString generates and returns formatted public key string for authorized_keys file. | ||||
| func (key *PublicKey) GetAuthorizedString() string { | func (key *PublicKey) GetAuthorizedString() string { | ||||
| return fmt.Sprintf(_TPL_PUBLICK_KEY, appPath, key.ID, setting.CustomConf, key.Content) | |||||
| return fmt.Sprintf(_TPL_PUBLICK_KEY, setting.AppPath, key.ID, setting.CustomConf, key.Content) | |||||
| } | } | ||||
| func extractTypeFromBase64Key(key string) (string, error) { | func extractTypeFromBase64Key(key string) (string, error) { | ||||
| @@ -373,6 +352,19 @@ func GetPublicKeyByID(keyID int64) (*PublicKey, error) { | |||||
| return key, nil | return key, nil | ||||
| } | } | ||||
| // SearchPublicKeyByContent searches content as prefix (leak e-mail part) | |||||
| // and returns public key found. | |||||
| func SearchPublicKeyByContent(content string) (*PublicKey, error) { | |||||
| key := new(PublicKey) | |||||
| has, err := x.Where("content like ?", content+"%").Get(key) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } else if !has { | |||||
| return nil, ErrKeyNotExist{} | |||||
| } | |||||
| return key, nil | |||||
| } | |||||
| // ListPublicKeys returns a list of public keys belongs to given user. | // ListPublicKeys returns a list of public keys belongs to given user. | ||||
| func ListPublicKeys(uid int64) ([]*PublicKey, error) { | func ListPublicKeys(uid int64) ([]*PublicKey, error) { | ||||
| keys := make([]*PublicKey, 0, 5) | keys := make([]*PublicKey, 0, 5) | ||||
| @@ -380,7 +380,7 @@ func (repo *Repository) CloneLink() (cl CloneLink, err error) { | |||||
| } | } | ||||
| if setting.SSHPort != 22 { | if setting.SSHPort != 22 { | ||||
| cl.SSH = fmt.Sprintf("ssh://%s@%s:%d/%s/%s.git", setting.RunUser, setting.SSHDomain, setting.SSHPort, repo.Owner.LowerName, repo.LowerName) | |||||
| cl.SSH = fmt.Sprintf("ssh://%s@%s:%d/%s/%s.git", setting.RunUser, setting.SSHDomain, setting.SSHPort, repo.Owner.Name, repo.Name) | |||||
| } else { | } else { | ||||
| cl.SSH = fmt.Sprintf("%s@%s:%s/%s.git", setting.RunUser, setting.SSHDomain, repo.Owner.Name, repo.Name) | cl.SSH = fmt.Sprintf("%s@%s:%s/%s.git", setting.RunUser, setting.SSHDomain, repo.Owner.Name, repo.Name) | ||||
| } | } | ||||
| @@ -599,7 +599,7 @@ func createUpdateHook(repoPath string) error { | |||||
| hookPath := path.Join(repoPath, "hooks/update") | hookPath := path.Join(repoPath, "hooks/update") | ||||
| os.MkdirAll(path.Dir(hookPath), os.ModePerm) | os.MkdirAll(path.Dir(hookPath), os.ModePerm) | ||||
| return ioutil.WriteFile(hookPath, | return ioutil.WriteFile(hookPath, | ||||
| []byte(fmt.Sprintf(_TPL_UPDATE_HOOK, setting.ScriptType, "\""+appPath+"\"", setting.CustomConf)), 0777) | |||||
| []byte(fmt.Sprintf(_TPL_UPDATE_HOOK, setting.ScriptType, "\""+setting.AppPath+"\"", setting.CustomConf)), 0777) | |||||
| } | } | ||||
| type CreateRepoOptions struct { | type CreateRepoOptions struct { | ||||
| @@ -4,6 +4,12 @@ | |||||
| package base | package base | ||||
| import ( | |||||
| "os" | |||||
| "os/exec" | |||||
| "path/filepath" | |||||
| ) | |||||
| const DOC_URL = "https://github.com/gogits/go-gogs-client/wiki" | const DOC_URL = "https://github.com/gogits/go-gogs-client/wiki" | ||||
| type ( | type ( | ||||
| @@ -11,3 +17,16 @@ type ( | |||||
| ) | ) | ||||
| var GoGetMetas = make(map[string]bool) | var GoGetMetas = make(map[string]bool) | ||||
| // ExecPath returns the executable path. | |||||
| func ExecPath() (string, error) { | |||||
| file, err := exec.LookPath(os.Args[0]) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| p, err := filepath.Abs(file) | |||||
| if err != nil { | |||||
| return "", err | |||||
| } | |||||
| return p, nil | |||||
| } | |||||
| @@ -1,615 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| /* | |||||
| Package agent implements a client to an ssh-agent daemon. | |||||
| References: | |||||
| [PROTOCOL.agent]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent?rev=HEAD | |||||
| */ | |||||
| package agent | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/dsa" | |||||
| "crypto/ecdsa" | |||||
| "crypto/elliptic" | |||||
| "crypto/rsa" | |||||
| "encoding/base64" | |||||
| "encoding/binary" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "math/big" | |||||
| "sync" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| // Agent represents the capabilities of an ssh-agent. | |||||
| type Agent interface { | |||||
| // List returns the identities known to the agent. | |||||
| List() ([]*Key, error) | |||||
| // Sign has the agent sign the data using a protocol 2 key as defined | |||||
| // in [PROTOCOL.agent] section 2.6.2. | |||||
| Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) | |||||
| // Add adds a private key to the agent. | |||||
| Add(key AddedKey) error | |||||
| // Remove removes all identities with the given public key. | |||||
| Remove(key ssh.PublicKey) error | |||||
| // RemoveAll removes all identities. | |||||
| RemoveAll() error | |||||
| // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. | |||||
| Lock(passphrase []byte) error | |||||
| // Unlock undoes the effect of Lock | |||||
| Unlock(passphrase []byte) error | |||||
| // Signers returns signers for all the known keys. | |||||
| Signers() ([]ssh.Signer, error) | |||||
| } | |||||
| // AddedKey describes an SSH key to be added to an Agent. | |||||
| type AddedKey struct { | |||||
| // PrivateKey must be a *rsa.PrivateKey, *dsa.PrivateKey or | |||||
| // *ecdsa.PrivateKey, which will be inserted into the agent. | |||||
| PrivateKey interface{} | |||||
| // Certificate, if not nil, is communicated to the agent and will be | |||||
| // stored with the key. | |||||
| Certificate *ssh.Certificate | |||||
| // Comment is an optional, free-form string. | |||||
| Comment string | |||||
| // LifetimeSecs, if not zero, is the number of seconds that the | |||||
| // agent will store the key for. | |||||
| LifetimeSecs uint32 | |||||
| // ConfirmBeforeUse, if true, requests that the agent confirm with the | |||||
| // user before each use of this key. | |||||
| ConfirmBeforeUse bool | |||||
| } | |||||
| // See [PROTOCOL.agent], section 3. | |||||
| const ( | |||||
| agentRequestV1Identities = 1 | |||||
| // 3.2 Requests from client to agent for protocol 2 key operations | |||||
| agentAddIdentity = 17 | |||||
| agentRemoveIdentity = 18 | |||||
| agentRemoveAllIdentities = 19 | |||||
| agentAddIdConstrained = 25 | |||||
| // 3.3 Key-type independent requests from client to agent | |||||
| agentAddSmartcardKey = 20 | |||||
| agentRemoveSmartcardKey = 21 | |||||
| agentLock = 22 | |||||
| agentUnlock = 23 | |||||
| agentAddSmartcardKeyConstrained = 26 | |||||
| // 3.7 Key constraint identifiers | |||||
| agentConstrainLifetime = 1 | |||||
| agentConstrainConfirm = 2 | |||||
| ) | |||||
| // maxAgentResponseBytes is the maximum agent reply size that is accepted. This | |||||
| // is a sanity check, not a limit in the spec. | |||||
| const maxAgentResponseBytes = 16 << 20 | |||||
| // Agent messages: | |||||
| // These structures mirror the wire format of the corresponding ssh agent | |||||
| // messages found in [PROTOCOL.agent]. | |||||
| // 3.4 Generic replies from agent to client | |||||
| const agentFailure = 5 | |||||
| type failureAgentMsg struct{} | |||||
| const agentSuccess = 6 | |||||
| type successAgentMsg struct{} | |||||
| // See [PROTOCOL.agent], section 2.5.2. | |||||
| const agentRequestIdentities = 11 | |||||
| type requestIdentitiesAgentMsg struct{} | |||||
| // See [PROTOCOL.agent], section 2.5.2. | |||||
| const agentIdentitiesAnswer = 12 | |||||
| type identitiesAnswerAgentMsg struct { | |||||
| NumKeys uint32 `sshtype:"12"` | |||||
| Keys []byte `ssh:"rest"` | |||||
| } | |||||
| // See [PROTOCOL.agent], section 2.6.2. | |||||
| const agentSignRequest = 13 | |||||
| type signRequestAgentMsg struct { | |||||
| KeyBlob []byte `sshtype:"13"` | |||||
| Data []byte | |||||
| Flags uint32 | |||||
| } | |||||
| // See [PROTOCOL.agent], section 2.6.2. | |||||
| // 3.6 Replies from agent to client for protocol 2 key operations | |||||
| const agentSignResponse = 14 | |||||
| type signResponseAgentMsg struct { | |||||
| SigBlob []byte `sshtype:"14"` | |||||
| } | |||||
| type publicKey struct { | |||||
| Format string | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| // Key represents a protocol 2 public key as defined in | |||||
| // [PROTOCOL.agent], section 2.5.2. | |||||
| type Key struct { | |||||
| Format string | |||||
| Blob []byte | |||||
| Comment string | |||||
| } | |||||
| func clientErr(err error) error { | |||||
| return fmt.Errorf("agent: client error: %v", err) | |||||
| } | |||||
| // String returns the storage form of an agent key with the format, base64 | |||||
| // encoded serialized key, and the comment if it is not empty. | |||||
| func (k *Key) String() string { | |||||
| s := string(k.Format) + " " + base64.StdEncoding.EncodeToString(k.Blob) | |||||
| if k.Comment != "" { | |||||
| s += " " + k.Comment | |||||
| } | |||||
| return s | |||||
| } | |||||
| // Type returns the public key type. | |||||
| func (k *Key) Type() string { | |||||
| return k.Format | |||||
| } | |||||
| // Marshal returns key blob to satisfy the ssh.PublicKey interface. | |||||
| func (k *Key) Marshal() []byte { | |||||
| return k.Blob | |||||
| } | |||||
| // Verify satisfies the ssh.PublicKey interface, but is not | |||||
| // implemented for agent keys. | |||||
| func (k *Key) Verify(data []byte, sig *ssh.Signature) error { | |||||
| return errors.New("agent: agent key does not know how to verify") | |||||
| } | |||||
| type wireKey struct { | |||||
| Format string | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| func parseKey(in []byte) (out *Key, rest []byte, err error) { | |||||
| var record struct { | |||||
| Blob []byte | |||||
| Comment string | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| if err := ssh.Unmarshal(in, &record); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| var wk wireKey | |||||
| if err := ssh.Unmarshal(record.Blob, &wk); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| return &Key{ | |||||
| Format: wk.Format, | |||||
| Blob: record.Blob, | |||||
| Comment: record.Comment, | |||||
| }, record.Rest, nil | |||||
| } | |||||
| // client is a client for an ssh-agent process. | |||||
| type client struct { | |||||
| // conn is typically a *net.UnixConn | |||||
| conn io.ReadWriter | |||||
| // mu is used to prevent concurrent access to the agent | |||||
| mu sync.Mutex | |||||
| } | |||||
| // NewClient returns an Agent that talks to an ssh-agent process over | |||||
| // the given connection. | |||||
| func NewClient(rw io.ReadWriter) Agent { | |||||
| return &client{conn: rw} | |||||
| } | |||||
| // call sends an RPC to the agent. On success, the reply is | |||||
| // unmarshaled into reply and replyType is set to the first byte of | |||||
| // the reply, which contains the type of the message. | |||||
| func (c *client) call(req []byte) (reply interface{}, err error) { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| msg := make([]byte, 4+len(req)) | |||||
| binary.BigEndian.PutUint32(msg, uint32(len(req))) | |||||
| copy(msg[4:], req) | |||||
| if _, err = c.conn.Write(msg); err != nil { | |||||
| return nil, clientErr(err) | |||||
| } | |||||
| var respSizeBuf [4]byte | |||||
| if _, err = io.ReadFull(c.conn, respSizeBuf[:]); err != nil { | |||||
| return nil, clientErr(err) | |||||
| } | |||||
| respSize := binary.BigEndian.Uint32(respSizeBuf[:]) | |||||
| if respSize > maxAgentResponseBytes { | |||||
| return nil, clientErr(err) | |||||
| } | |||||
| buf := make([]byte, respSize) | |||||
| if _, err = io.ReadFull(c.conn, buf); err != nil { | |||||
| return nil, clientErr(err) | |||||
| } | |||||
| reply, err = unmarshal(buf) | |||||
| if err != nil { | |||||
| return nil, clientErr(err) | |||||
| } | |||||
| return reply, err | |||||
| } | |||||
| func (c *client) simpleCall(req []byte) error { | |||||
| resp, err := c.call(req) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if _, ok := resp.(*successAgentMsg); ok { | |||||
| return nil | |||||
| } | |||||
| return errors.New("agent: failure") | |||||
| } | |||||
| func (c *client) RemoveAll() error { | |||||
| return c.simpleCall([]byte{agentRemoveAllIdentities}) | |||||
| } | |||||
| func (c *client) Remove(key ssh.PublicKey) error { | |||||
| req := ssh.Marshal(&agentRemoveIdentityMsg{ | |||||
| KeyBlob: key.Marshal(), | |||||
| }) | |||||
| return c.simpleCall(req) | |||||
| } | |||||
| func (c *client) Lock(passphrase []byte) error { | |||||
| req := ssh.Marshal(&agentLockMsg{ | |||||
| Passphrase: passphrase, | |||||
| }) | |||||
| return c.simpleCall(req) | |||||
| } | |||||
| func (c *client) Unlock(passphrase []byte) error { | |||||
| req := ssh.Marshal(&agentUnlockMsg{ | |||||
| Passphrase: passphrase, | |||||
| }) | |||||
| return c.simpleCall(req) | |||||
| } | |||||
| // List returns the identities known to the agent. | |||||
| func (c *client) List() ([]*Key, error) { | |||||
| // see [PROTOCOL.agent] section 2.5.2. | |||||
| req := []byte{agentRequestIdentities} | |||||
| msg, err := c.call(req) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| switch msg := msg.(type) { | |||||
| case *identitiesAnswerAgentMsg: | |||||
| if msg.NumKeys > maxAgentResponseBytes/8 { | |||||
| return nil, errors.New("agent: too many keys in agent reply") | |||||
| } | |||||
| keys := make([]*Key, msg.NumKeys) | |||||
| data := msg.Keys | |||||
| for i := uint32(0); i < msg.NumKeys; i++ { | |||||
| var key *Key | |||||
| var err error | |||||
| if key, data, err = parseKey(data); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| keys[i] = key | |||||
| } | |||||
| return keys, nil | |||||
| case *failureAgentMsg: | |||||
| return nil, errors.New("agent: failed to list keys") | |||||
| } | |||||
| panic("unreachable") | |||||
| } | |||||
| // Sign has the agent sign the data using a protocol 2 key as defined | |||||
| // in [PROTOCOL.agent] section 2.6.2. | |||||
| func (c *client) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { | |||||
| req := ssh.Marshal(signRequestAgentMsg{ | |||||
| KeyBlob: key.Marshal(), | |||||
| Data: data, | |||||
| }) | |||||
| msg, err := c.call(req) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| switch msg := msg.(type) { | |||||
| case *signResponseAgentMsg: | |||||
| var sig ssh.Signature | |||||
| if err := ssh.Unmarshal(msg.SigBlob, &sig); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &sig, nil | |||||
| case *failureAgentMsg: | |||||
| return nil, errors.New("agent: failed to sign challenge") | |||||
| } | |||||
| panic("unreachable") | |||||
| } | |||||
| // unmarshal parses an agent message in packet, returning the parsed | |||||
| // form and the message type of packet. | |||||
| func unmarshal(packet []byte) (interface{}, error) { | |||||
| if len(packet) < 1 { | |||||
| return nil, errors.New("agent: empty packet") | |||||
| } | |||||
| var msg interface{} | |||||
| switch packet[0] { | |||||
| case agentFailure: | |||||
| return new(failureAgentMsg), nil | |||||
| case agentSuccess: | |||||
| return new(successAgentMsg), nil | |||||
| case agentIdentitiesAnswer: | |||||
| msg = new(identitiesAnswerAgentMsg) | |||||
| case agentSignResponse: | |||||
| msg = new(signResponseAgentMsg) | |||||
| default: | |||||
| return nil, fmt.Errorf("agent: unknown type tag %d", packet[0]) | |||||
| } | |||||
| if err := ssh.Unmarshal(packet, msg); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return msg, nil | |||||
| } | |||||
| type rsaKeyMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| N *big.Int | |||||
| E *big.Int | |||||
| D *big.Int | |||||
| Iqmp *big.Int // IQMP = Inverse Q Mod P | |||||
| P *big.Int | |||||
| Q *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| type dsaKeyMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| P *big.Int | |||||
| Q *big.Int | |||||
| G *big.Int | |||||
| Y *big.Int | |||||
| X *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| type ecdsaKeyMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| Curve string | |||||
| KeyBytes []byte | |||||
| D *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| // Insert adds a private key to the agent. | |||||
| func (c *client) insertKey(s interface{}, comment string, constraints []byte) error { | |||||
| var req []byte | |||||
| switch k := s.(type) { | |||||
| case *rsa.PrivateKey: | |||||
| if len(k.Primes) != 2 { | |||||
| return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) | |||||
| } | |||||
| k.Precompute() | |||||
| req = ssh.Marshal(rsaKeyMsg{ | |||||
| Type: ssh.KeyAlgoRSA, | |||||
| N: k.N, | |||||
| E: big.NewInt(int64(k.E)), | |||||
| D: k.D, | |||||
| Iqmp: k.Precomputed.Qinv, | |||||
| P: k.Primes[0], | |||||
| Q: k.Primes[1], | |||||
| Comments: comment, | |||||
| Constraints: constraints, | |||||
| }) | |||||
| case *dsa.PrivateKey: | |||||
| req = ssh.Marshal(dsaKeyMsg{ | |||||
| Type: ssh.KeyAlgoDSA, | |||||
| P: k.P, | |||||
| Q: k.Q, | |||||
| G: k.G, | |||||
| Y: k.Y, | |||||
| X: k.X, | |||||
| Comments: comment, | |||||
| Constraints: constraints, | |||||
| }) | |||||
| case *ecdsa.PrivateKey: | |||||
| nistID := fmt.Sprintf("nistp%d", k.Params().BitSize) | |||||
| req = ssh.Marshal(ecdsaKeyMsg{ | |||||
| Type: "ecdsa-sha2-" + nistID, | |||||
| Curve: nistID, | |||||
| KeyBytes: elliptic.Marshal(k.Curve, k.X, k.Y), | |||||
| D: k.D, | |||||
| Comments: comment, | |||||
| Constraints: constraints, | |||||
| }) | |||||
| default: | |||||
| return fmt.Errorf("agent: unsupported key type %T", s) | |||||
| } | |||||
| // if constraints are present then the message type needs to be changed. | |||||
| if len(constraints) != 0 { | |||||
| req[0] = agentAddIdConstrained | |||||
| } | |||||
| resp, err := c.call(req) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if _, ok := resp.(*successAgentMsg); ok { | |||||
| return nil | |||||
| } | |||||
| return errors.New("agent: failure") | |||||
| } | |||||
| type rsaCertMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| CertBytes []byte | |||||
| D *big.Int | |||||
| Iqmp *big.Int // IQMP = Inverse Q Mod P | |||||
| P *big.Int | |||||
| Q *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| type dsaCertMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| CertBytes []byte | |||||
| X *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| type ecdsaCertMsg struct { | |||||
| Type string `sshtype:"17"` | |||||
| CertBytes []byte | |||||
| D *big.Int | |||||
| Comments string | |||||
| Constraints []byte `ssh:"rest"` | |||||
| } | |||||
| // Insert adds a private key to the agent. If a certificate is given, | |||||
| // that certificate is added instead as public key. | |||||
| func (c *client) Add(key AddedKey) error { | |||||
| var constraints []byte | |||||
| if secs := key.LifetimeSecs; secs != 0 { | |||||
| constraints = append(constraints, agentConstrainLifetime) | |||||
| var secsBytes [4]byte | |||||
| binary.BigEndian.PutUint32(secsBytes[:], secs) | |||||
| constraints = append(constraints, secsBytes[:]...) | |||||
| } | |||||
| if key.ConfirmBeforeUse { | |||||
| constraints = append(constraints, agentConstrainConfirm) | |||||
| } | |||||
| if cert := key.Certificate; cert == nil { | |||||
| return c.insertKey(key.PrivateKey, key.Comment, constraints) | |||||
| } else { | |||||
| return c.insertCert(key.PrivateKey, cert, key.Comment, constraints) | |||||
| } | |||||
| } | |||||
| func (c *client) insertCert(s interface{}, cert *ssh.Certificate, comment string, constraints []byte) error { | |||||
| var req []byte | |||||
| switch k := s.(type) { | |||||
| case *rsa.PrivateKey: | |||||
| if len(k.Primes) != 2 { | |||||
| return fmt.Errorf("agent: unsupported RSA key with %d primes", len(k.Primes)) | |||||
| } | |||||
| k.Precompute() | |||||
| req = ssh.Marshal(rsaCertMsg{ | |||||
| Type: cert.Type(), | |||||
| CertBytes: cert.Marshal(), | |||||
| D: k.D, | |||||
| Iqmp: k.Precomputed.Qinv, | |||||
| P: k.Primes[0], | |||||
| Q: k.Primes[1], | |||||
| Comments: comment, | |||||
| Constraints: constraints, | |||||
| }) | |||||
| case *dsa.PrivateKey: | |||||
| req = ssh.Marshal(dsaCertMsg{ | |||||
| Type: cert.Type(), | |||||
| CertBytes: cert.Marshal(), | |||||
| X: k.X, | |||||
| Comments: comment, | |||||
| }) | |||||
| case *ecdsa.PrivateKey: | |||||
| req = ssh.Marshal(ecdsaCertMsg{ | |||||
| Type: cert.Type(), | |||||
| CertBytes: cert.Marshal(), | |||||
| D: k.D, | |||||
| Comments: comment, | |||||
| }) | |||||
| default: | |||||
| return fmt.Errorf("agent: unsupported key type %T", s) | |||||
| } | |||||
| // if constraints are present then the message type needs to be changed. | |||||
| if len(constraints) != 0 { | |||||
| req[0] = agentAddIdConstrained | |||||
| } | |||||
| signer, err := ssh.NewSignerFromKey(s) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { | |||||
| return errors.New("agent: signer and cert have different public key") | |||||
| } | |||||
| resp, err := c.call(req) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if _, ok := resp.(*successAgentMsg); ok { | |||||
| return nil | |||||
| } | |||||
| return errors.New("agent: failure") | |||||
| } | |||||
| // Signers provides a callback for client authentication. | |||||
| func (c *client) Signers() ([]ssh.Signer, error) { | |||||
| keys, err := c.List() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var result []ssh.Signer | |||||
| for _, k := range keys { | |||||
| result = append(result, &agentKeyringSigner{c, k}) | |||||
| } | |||||
| return result, nil | |||||
| } | |||||
| type agentKeyringSigner struct { | |||||
| agent *client | |||||
| pub ssh.PublicKey | |||||
| } | |||||
| func (s *agentKeyringSigner) PublicKey() ssh.PublicKey { | |||||
| return s.pub | |||||
| } | |||||
| func (s *agentKeyringSigner) Sign(rand io.Reader, data []byte) (*ssh.Signature, error) { | |||||
| // The agent has its own entropy source, so the rand argument is ignored. | |||||
| return s.agent.Sign(s.pub, data) | |||||
| } | |||||
| @@ -1,287 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package agent | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "errors" | |||||
| "net" | |||||
| "os" | |||||
| "os/exec" | |||||
| "path/filepath" | |||||
| "strconv" | |||||
| "testing" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| // startAgent executes ssh-agent, and returns a Agent interface to it. | |||||
| func startAgent(t *testing.T) (client Agent, socket string, cleanup func()) { | |||||
| if testing.Short() { | |||||
| // ssh-agent is not always available, and the key | |||||
| // types supported vary by platform. | |||||
| t.Skip("skipping test due to -short") | |||||
| } | |||||
| bin, err := exec.LookPath("ssh-agent") | |||||
| if err != nil { | |||||
| t.Skip("could not find ssh-agent") | |||||
| } | |||||
| cmd := exec.Command(bin, "-s") | |||||
| out, err := cmd.Output() | |||||
| if err != nil { | |||||
| t.Fatalf("cmd.Output: %v", err) | |||||
| } | |||||
| /* Output looks like: | |||||
| SSH_AUTH_SOCK=/tmp/ssh-P65gpcqArqvH/agent.15541; export SSH_AUTH_SOCK; | |||||
| SSH_AGENT_PID=15542; export SSH_AGENT_PID; | |||||
| echo Agent pid 15542; | |||||
| */ | |||||
| fields := bytes.Split(out, []byte(";")) | |||||
| line := bytes.SplitN(fields[0], []byte("="), 2) | |||||
| line[0] = bytes.TrimLeft(line[0], "\n") | |||||
| if string(line[0]) != "SSH_AUTH_SOCK" { | |||||
| t.Fatalf("could not find key SSH_AUTH_SOCK in %q", fields[0]) | |||||
| } | |||||
| socket = string(line[1]) | |||||
| line = bytes.SplitN(fields[2], []byte("="), 2) | |||||
| line[0] = bytes.TrimLeft(line[0], "\n") | |||||
| if string(line[0]) != "SSH_AGENT_PID" { | |||||
| t.Fatalf("could not find key SSH_AGENT_PID in %q", fields[2]) | |||||
| } | |||||
| pidStr := line[1] | |||||
| pid, err := strconv.Atoi(string(pidStr)) | |||||
| if err != nil { | |||||
| t.Fatalf("Atoi(%q): %v", pidStr, err) | |||||
| } | |||||
| conn, err := net.Dial("unix", string(socket)) | |||||
| if err != nil { | |||||
| t.Fatalf("net.Dial: %v", err) | |||||
| } | |||||
| ac := NewClient(conn) | |||||
| return ac, socket, func() { | |||||
| proc, _ := os.FindProcess(pid) | |||||
| if proc != nil { | |||||
| proc.Kill() | |||||
| } | |||||
| conn.Close() | |||||
| os.RemoveAll(filepath.Dir(socket)) | |||||
| } | |||||
| } | |||||
| func testAgent(t *testing.T, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { | |||||
| agent, _, cleanup := startAgent(t) | |||||
| defer cleanup() | |||||
| testAgentInterface(t, agent, key, cert, lifetimeSecs) | |||||
| } | |||||
| func testAgentInterface(t *testing.T, agent Agent, key interface{}, cert *ssh.Certificate, lifetimeSecs uint32) { | |||||
| signer, err := ssh.NewSignerFromKey(key) | |||||
| if err != nil { | |||||
| t.Fatalf("NewSignerFromKey(%T): %v", key, err) | |||||
| } | |||||
| // The agent should start up empty. | |||||
| if keys, err := agent.List(); err != nil { | |||||
| t.Fatalf("RequestIdentities: %v", err) | |||||
| } else if len(keys) > 0 { | |||||
| t.Fatalf("got %d keys, want 0: %v", len(keys), keys) | |||||
| } | |||||
| // Attempt to insert the key, with certificate if specified. | |||||
| var pubKey ssh.PublicKey | |||||
| if cert != nil { | |||||
| err = agent.Add(AddedKey{ | |||||
| PrivateKey: key, | |||||
| Certificate: cert, | |||||
| Comment: "comment", | |||||
| LifetimeSecs: lifetimeSecs, | |||||
| }) | |||||
| pubKey = cert | |||||
| } else { | |||||
| err = agent.Add(AddedKey{PrivateKey: key, Comment: "comment", LifetimeSecs: lifetimeSecs}) | |||||
| pubKey = signer.PublicKey() | |||||
| } | |||||
| if err != nil { | |||||
| t.Fatalf("insert(%T): %v", key, err) | |||||
| } | |||||
| // Did the key get inserted successfully? | |||||
| if keys, err := agent.List(); err != nil { | |||||
| t.Fatalf("List: %v", err) | |||||
| } else if len(keys) != 1 { | |||||
| t.Fatalf("got %v, want 1 key", keys) | |||||
| } else if keys[0].Comment != "comment" { | |||||
| t.Fatalf("key comment: got %v, want %v", keys[0].Comment, "comment") | |||||
| } else if !bytes.Equal(keys[0].Blob, pubKey.Marshal()) { | |||||
| t.Fatalf("key mismatch") | |||||
| } | |||||
| // Can the agent make a valid signature? | |||||
| data := []byte("hello") | |||||
| sig, err := agent.Sign(pubKey, data) | |||||
| if err != nil { | |||||
| t.Fatalf("Sign(%s): %v", pubKey.Type(), err) | |||||
| } | |||||
| if err := pubKey.Verify(data, sig); err != nil { | |||||
| t.Fatalf("Verify(%s): %v", pubKey.Type(), err) | |||||
| } | |||||
| } | |||||
| func TestAgent(t *testing.T) { | |||||
| for _, keyType := range []string{"rsa", "dsa", "ecdsa"} { | |||||
| testAgent(t, testPrivateKeys[keyType], nil, 0) | |||||
| } | |||||
| } | |||||
| func TestCert(t *testing.T) { | |||||
| cert := &ssh.Certificate{ | |||||
| Key: testPublicKeys["rsa"], | |||||
| ValidBefore: ssh.CertTimeInfinity, | |||||
| CertType: ssh.UserCert, | |||||
| } | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| testAgent(t, testPrivateKeys["rsa"], cert, 0) | |||||
| } | |||||
| func TestConstraints(t *testing.T) { | |||||
| testAgent(t, testPrivateKeys["rsa"], nil, 3600 /* lifetime in seconds */) | |||||
| } | |||||
| // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and | |||||
| // therefore is buffered (net.Pipe deadlocks if both sides start with | |||||
| // a write.) | |||||
| func netPipe() (net.Conn, net.Conn, error) { | |||||
| listener, err := net.Listen("tcp", "127.0.0.1:0") | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| defer listener.Close() | |||||
| c1, err := net.Dial("tcp", listener.Addr().String()) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| c2, err := listener.Accept() | |||||
| if err != nil { | |||||
| c1.Close() | |||||
| return nil, nil, err | |||||
| } | |||||
| return c1, c2, nil | |||||
| } | |||||
| func TestAuth(t *testing.T) { | |||||
| a, b, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| agent, _, cleanup := startAgent(t) | |||||
| defer cleanup() | |||||
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { | |||||
| t.Errorf("Add: %v", err) | |||||
| } | |||||
| serverConf := ssh.ServerConfig{} | |||||
| serverConf.AddHostKey(testSigners["rsa"]) | |||||
| serverConf.PublicKeyCallback = func(c ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { | |||||
| if bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { | |||||
| return nil, nil | |||||
| } | |||||
| return nil, errors.New("pubkey rejected") | |||||
| } | |||||
| go func() { | |||||
| conn, _, _, err := ssh.NewServerConn(a, &serverConf) | |||||
| if err != nil { | |||||
| t.Fatalf("Server: %v", err) | |||||
| } | |||||
| conn.Close() | |||||
| }() | |||||
| conf := ssh.ClientConfig{} | |||||
| conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) | |||||
| conn, _, _, err := ssh.NewClientConn(b, "", &conf) | |||||
| if err != nil { | |||||
| t.Fatalf("NewClientConn: %v", err) | |||||
| } | |||||
| conn.Close() | |||||
| } | |||||
| func TestLockClient(t *testing.T) { | |||||
| agent, _, cleanup := startAgent(t) | |||||
| defer cleanup() | |||||
| testLockAgent(agent, t) | |||||
| } | |||||
| func testLockAgent(agent Agent, t *testing.T) { | |||||
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment 1"}); err != nil { | |||||
| t.Errorf("Add: %v", err) | |||||
| } | |||||
| if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["dsa"], Comment: "comment dsa"}); err != nil { | |||||
| t.Errorf("Add: %v", err) | |||||
| } | |||||
| if keys, err := agent.List(); err != nil { | |||||
| t.Errorf("List: %v", err) | |||||
| } else if len(keys) != 2 { | |||||
| t.Errorf("Want 2 keys, got %v", keys) | |||||
| } | |||||
| passphrase := []byte("secret") | |||||
| if err := agent.Lock(passphrase); err != nil { | |||||
| t.Errorf("Lock: %v", err) | |||||
| } | |||||
| if keys, err := agent.List(); err != nil { | |||||
| t.Errorf("List: %v", err) | |||||
| } else if len(keys) != 0 { | |||||
| t.Errorf("Want 0 keys, got %v", keys) | |||||
| } | |||||
| signer, _ := ssh.NewSignerFromKey(testPrivateKeys["rsa"]) | |||||
| if _, err := agent.Sign(signer.PublicKey(), []byte("hello")); err == nil { | |||||
| t.Fatalf("Sign did not fail") | |||||
| } | |||||
| if err := agent.Remove(signer.PublicKey()); err == nil { | |||||
| t.Fatalf("Remove did not fail") | |||||
| } | |||||
| if err := agent.RemoveAll(); err == nil { | |||||
| t.Fatalf("RemoveAll did not fail") | |||||
| } | |||||
| if err := agent.Unlock(nil); err == nil { | |||||
| t.Errorf("Unlock with wrong passphrase succeeded") | |||||
| } | |||||
| if err := agent.Unlock(passphrase); err != nil { | |||||
| t.Errorf("Unlock: %v", err) | |||||
| } | |||||
| if err := agent.Remove(signer.PublicKey()); err != nil { | |||||
| t.Fatalf("Remove: %v", err) | |||||
| } | |||||
| if keys, err := agent.List(); err != nil { | |||||
| t.Errorf("List: %v", err) | |||||
| } else if len(keys) != 1 { | |||||
| t.Errorf("Want 1 keys, got %v", keys) | |||||
| } | |||||
| } | |||||
| @@ -1,103 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package agent | |||||
| import ( | |||||
| "errors" | |||||
| "io" | |||||
| "net" | |||||
| "sync" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| // RequestAgentForwarding sets up agent forwarding for the session. | |||||
| // ForwardToAgent or ForwardToRemote should be called to route | |||||
| // the authentication requests. | |||||
| func RequestAgentForwarding(session *ssh.Session) error { | |||||
| ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if !ok { | |||||
| return errors.New("forwarding request denied") | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // ForwardToAgent routes authentication requests to the given keyring. | |||||
| func ForwardToAgent(client *ssh.Client, keyring Agent) error { | |||||
| channels := client.HandleChannelOpen(channelType) | |||||
| if channels == nil { | |||||
| return errors.New("agent: already have handler for " + channelType) | |||||
| } | |||||
| go func() { | |||||
| for ch := range channels { | |||||
| channel, reqs, err := ch.Accept() | |||||
| if err != nil { | |||||
| continue | |||||
| } | |||||
| go ssh.DiscardRequests(reqs) | |||||
| go func() { | |||||
| ServeAgent(keyring, channel) | |||||
| channel.Close() | |||||
| }() | |||||
| } | |||||
| }() | |||||
| return nil | |||||
| } | |||||
| const channelType = "auth-agent@openssh.com" | |||||
| // ForwardToRemote routes authentication requests to the ssh-agent | |||||
| // process serving on the given unix socket. | |||||
| func ForwardToRemote(client *ssh.Client, addr string) error { | |||||
| channels := client.HandleChannelOpen(channelType) | |||||
| if channels == nil { | |||||
| return errors.New("agent: already have handler for " + channelType) | |||||
| } | |||||
| conn, err := net.Dial("unix", addr) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| conn.Close() | |||||
| go func() { | |||||
| for ch := range channels { | |||||
| channel, reqs, err := ch.Accept() | |||||
| if err != nil { | |||||
| continue | |||||
| } | |||||
| go ssh.DiscardRequests(reqs) | |||||
| go forwardUnixSocket(channel, addr) | |||||
| } | |||||
| }() | |||||
| return nil | |||||
| } | |||||
| func forwardUnixSocket(channel ssh.Channel, addr string) { | |||||
| conn, err := net.Dial("unix", addr) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| var wg sync.WaitGroup | |||||
| wg.Add(2) | |||||
| go func() { | |||||
| io.Copy(conn, channel) | |||||
| conn.(*net.UnixConn).CloseWrite() | |||||
| wg.Done() | |||||
| }() | |||||
| go func() { | |||||
| io.Copy(channel, conn) | |||||
| channel.CloseWrite() | |||||
| wg.Done() | |||||
| }() | |||||
| wg.Wait() | |||||
| conn.Close() | |||||
| channel.Close() | |||||
| } | |||||
| @@ -1,184 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package agent | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "crypto/subtle" | |||||
| "errors" | |||||
| "fmt" | |||||
| "sync" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| type privKey struct { | |||||
| signer ssh.Signer | |||||
| comment string | |||||
| } | |||||
| type keyring struct { | |||||
| mu sync.Mutex | |||||
| keys []privKey | |||||
| locked bool | |||||
| passphrase []byte | |||||
| } | |||||
| var errLocked = errors.New("agent: locked") | |||||
| // NewKeyring returns an Agent that holds keys in memory. It is safe | |||||
| // for concurrent use by multiple goroutines. | |||||
| func NewKeyring() Agent { | |||||
| return &keyring{} | |||||
| } | |||||
| // RemoveAll removes all identities. | |||||
| func (r *keyring) RemoveAll() error { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return errLocked | |||||
| } | |||||
| r.keys = nil | |||||
| return nil | |||||
| } | |||||
| // Remove removes all identities with the given public key. | |||||
| func (r *keyring) Remove(key ssh.PublicKey) error { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return errLocked | |||||
| } | |||||
| want := key.Marshal() | |||||
| found := false | |||||
| for i := 0; i < len(r.keys); { | |||||
| if bytes.Equal(r.keys[i].signer.PublicKey().Marshal(), want) { | |||||
| found = true | |||||
| r.keys[i] = r.keys[len(r.keys)-1] | |||||
| r.keys = r.keys[len(r.keys)-1:] | |||||
| continue | |||||
| } else { | |||||
| i++ | |||||
| } | |||||
| } | |||||
| if !found { | |||||
| return errors.New("agent: key not found") | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // Lock locks the agent. Sign and Remove will fail, and List will empty an empty list. | |||||
| func (r *keyring) Lock(passphrase []byte) error { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return errLocked | |||||
| } | |||||
| r.locked = true | |||||
| r.passphrase = passphrase | |||||
| return nil | |||||
| } | |||||
| // Unlock undoes the effect of Lock | |||||
| func (r *keyring) Unlock(passphrase []byte) error { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if !r.locked { | |||||
| return errors.New("agent: not locked") | |||||
| } | |||||
| if len(passphrase) != len(r.passphrase) || 1 != subtle.ConstantTimeCompare(passphrase, r.passphrase) { | |||||
| return fmt.Errorf("agent: incorrect passphrase") | |||||
| } | |||||
| r.locked = false | |||||
| r.passphrase = nil | |||||
| return nil | |||||
| } | |||||
| // List returns the identities known to the agent. | |||||
| func (r *keyring) List() ([]*Key, error) { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| // section 2.7: locked agents return empty. | |||||
| return nil, nil | |||||
| } | |||||
| var ids []*Key | |||||
| for _, k := range r.keys { | |||||
| pub := k.signer.PublicKey() | |||||
| ids = append(ids, &Key{ | |||||
| Format: pub.Type(), | |||||
| Blob: pub.Marshal(), | |||||
| Comment: k.comment}) | |||||
| } | |||||
| return ids, nil | |||||
| } | |||||
| // Insert adds a private key to the keyring. If a certificate | |||||
| // is given, that certificate is added as public key. Note that | |||||
| // any constraints given are ignored. | |||||
| func (r *keyring) Add(key AddedKey) error { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return errLocked | |||||
| } | |||||
| signer, err := ssh.NewSignerFromKey(key.PrivateKey) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if cert := key.Certificate; cert != nil { | |||||
| signer, err = ssh.NewCertSigner(cert, signer) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| r.keys = append(r.keys, privKey{signer, key.Comment}) | |||||
| return nil | |||||
| } | |||||
| // Sign returns a signature for the data. | |||||
| func (r *keyring) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return nil, errLocked | |||||
| } | |||||
| wanted := key.Marshal() | |||||
| for _, k := range r.keys { | |||||
| if bytes.Equal(k.signer.PublicKey().Marshal(), wanted) { | |||||
| return k.signer.Sign(rand.Reader, data) | |||||
| } | |||||
| } | |||||
| return nil, errors.New("not found") | |||||
| } | |||||
| // Signers returns signers for all the known keys. | |||||
| func (r *keyring) Signers() ([]ssh.Signer, error) { | |||||
| r.mu.Lock() | |||||
| defer r.mu.Unlock() | |||||
| if r.locked { | |||||
| return nil, errLocked | |||||
| } | |||||
| s := make([]ssh.Signer, 0, len(r.keys)) | |||||
| for _, k := range r.keys { | |||||
| s = append(s, k.signer) | |||||
| } | |||||
| return s, nil | |||||
| } | |||||
| @@ -1,209 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package agent | |||||
| import ( | |||||
| "crypto/rsa" | |||||
| "encoding/binary" | |||||
| "fmt" | |||||
| "io" | |||||
| "log" | |||||
| "math/big" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| // Server wraps an Agent and uses it to implement the agent side of | |||||
| // the SSH-agent, wire protocol. | |||||
| type server struct { | |||||
| agent Agent | |||||
| } | |||||
| func (s *server) processRequestBytes(reqData []byte) []byte { | |||||
| rep, err := s.processRequest(reqData) | |||||
| if err != nil { | |||||
| if err != errLocked { | |||||
| // TODO(hanwen): provide better logging interface? | |||||
| log.Printf("agent %d: %v", reqData[0], err) | |||||
| } | |||||
| return []byte{agentFailure} | |||||
| } | |||||
| if err == nil && rep == nil { | |||||
| return []byte{agentSuccess} | |||||
| } | |||||
| return ssh.Marshal(rep) | |||||
| } | |||||
| func marshalKey(k *Key) []byte { | |||||
| var record struct { | |||||
| Blob []byte | |||||
| Comment string | |||||
| } | |||||
| record.Blob = k.Marshal() | |||||
| record.Comment = k.Comment | |||||
| return ssh.Marshal(&record) | |||||
| } | |||||
| type agentV1IdentityMsg struct { | |||||
| Numkeys uint32 `sshtype:"2"` | |||||
| } | |||||
| type agentRemoveIdentityMsg struct { | |||||
| KeyBlob []byte `sshtype:"18"` | |||||
| } | |||||
| type agentLockMsg struct { | |||||
| Passphrase []byte `sshtype:"22"` | |||||
| } | |||||
| type agentUnlockMsg struct { | |||||
| Passphrase []byte `sshtype:"23"` | |||||
| } | |||||
| func (s *server) processRequest(data []byte) (interface{}, error) { | |||||
| switch data[0] { | |||||
| case agentRequestV1Identities: | |||||
| return &agentV1IdentityMsg{0}, nil | |||||
| case agentRemoveIdentity: | |||||
| var req agentRemoveIdentityMsg | |||||
| if err := ssh.Unmarshal(data, &req); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var wk wireKey | |||||
| if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob}) | |||||
| case agentRemoveAllIdentities: | |||||
| return nil, s.agent.RemoveAll() | |||||
| case agentLock: | |||||
| var req agentLockMsg | |||||
| if err := ssh.Unmarshal(data, &req); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return nil, s.agent.Lock(req.Passphrase) | |||||
| case agentUnlock: | |||||
| var req agentLockMsg | |||||
| if err := ssh.Unmarshal(data, &req); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return nil, s.agent.Unlock(req.Passphrase) | |||||
| case agentSignRequest: | |||||
| var req signRequestAgentMsg | |||||
| if err := ssh.Unmarshal(data, &req); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var wk wireKey | |||||
| if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| k := &Key{ | |||||
| Format: wk.Format, | |||||
| Blob: req.KeyBlob, | |||||
| } | |||||
| sig, err := s.agent.Sign(k, req.Data) // TODO(hanwen): flags. | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil | |||||
| case agentRequestIdentities: | |||||
| keys, err := s.agent.List() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| rep := identitiesAnswerAgentMsg{ | |||||
| NumKeys: uint32(len(keys)), | |||||
| } | |||||
| for _, k := range keys { | |||||
| rep.Keys = append(rep.Keys, marshalKey(k)...) | |||||
| } | |||||
| return rep, nil | |||||
| case agentAddIdentity: | |||||
| return nil, s.insertIdentity(data) | |||||
| } | |||||
| return nil, fmt.Errorf("unknown opcode %d", data[0]) | |||||
| } | |||||
| func (s *server) insertIdentity(req []byte) error { | |||||
| var record struct { | |||||
| Type string `sshtype:"17"` | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| if err := ssh.Unmarshal(req, &record); err != nil { | |||||
| return err | |||||
| } | |||||
| switch record.Type { | |||||
| case ssh.KeyAlgoRSA: | |||||
| var k rsaKeyMsg | |||||
| if err := ssh.Unmarshal(req, &k); err != nil { | |||||
| return err | |||||
| } | |||||
| priv := rsa.PrivateKey{ | |||||
| PublicKey: rsa.PublicKey{ | |||||
| E: int(k.E.Int64()), | |||||
| N: k.N, | |||||
| }, | |||||
| D: k.D, | |||||
| Primes: []*big.Int{k.P, k.Q}, | |||||
| } | |||||
| priv.Precompute() | |||||
| return s.agent.Add(AddedKey{PrivateKey: &priv, Comment: k.Comments}) | |||||
| } | |||||
| return fmt.Errorf("not implemented: %s", record.Type) | |||||
| } | |||||
| // ServeAgent serves the agent protocol on the given connection. It | |||||
| // returns when an I/O error occurs. | |||||
| func ServeAgent(agent Agent, c io.ReadWriter) error { | |||||
| s := &server{agent} | |||||
| var length [4]byte | |||||
| for { | |||||
| if _, err := io.ReadFull(c, length[:]); err != nil { | |||||
| return err | |||||
| } | |||||
| l := binary.BigEndian.Uint32(length[:]) | |||||
| if l > maxAgentResponseBytes { | |||||
| // We also cap requests. | |||||
| return fmt.Errorf("agent: request too large: %d", l) | |||||
| } | |||||
| req := make([]byte, l) | |||||
| if _, err := io.ReadFull(c, req); err != nil { | |||||
| return err | |||||
| } | |||||
| repData := s.processRequestBytes(req) | |||||
| if len(repData) > maxAgentResponseBytes { | |||||
| return fmt.Errorf("agent: reply too large: %d bytes", len(repData)) | |||||
| } | |||||
| binary.BigEndian.PutUint32(length[:], uint32(len(repData))) | |||||
| if _, err := c.Write(length[:]); err != nil { | |||||
| return err | |||||
| } | |||||
| if _, err := c.Write(repData); err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,77 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package agent | |||||
| import ( | |||||
| "testing" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| ) | |||||
| func TestServer(t *testing.T) { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| client := NewClient(c1) | |||||
| go ServeAgent(NewKeyring(), c2) | |||||
| testAgentInterface(t, client, testPrivateKeys["rsa"], nil, 0) | |||||
| } | |||||
| func TestLockServer(t *testing.T) { | |||||
| testLockAgent(NewKeyring(), t) | |||||
| } | |||||
| func TestSetupForwardAgent(t *testing.T) { | |||||
| a, b, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| _, socket, cleanup := startAgent(t) | |||||
| defer cleanup() | |||||
| serverConf := ssh.ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| serverConf.AddHostKey(testSigners["rsa"]) | |||||
| incoming := make(chan *ssh.ServerConn, 1) | |||||
| go func() { | |||||
| conn, _, _, err := ssh.NewServerConn(a, &serverConf) | |||||
| if err != nil { | |||||
| t.Fatalf("Server: %v", err) | |||||
| } | |||||
| incoming <- conn | |||||
| }() | |||||
| conf := ssh.ClientConfig{} | |||||
| conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) | |||||
| if err != nil { | |||||
| t.Fatalf("NewClientConn: %v", err) | |||||
| } | |||||
| client := ssh.NewClient(conn, chans, reqs) | |||||
| if err := ForwardToRemote(client, socket); err != nil { | |||||
| t.Fatalf("SetupForwardAgent: %v", err) | |||||
| } | |||||
| server := <-incoming | |||||
| ch, reqs, err := server.OpenChannel(channelType, nil) | |||||
| if err != nil { | |||||
| t.Fatalf("OpenChannel(%q): %v", channelType, err) | |||||
| } | |||||
| go ssh.DiscardRequests(reqs) | |||||
| agentClient := NewClient(ch) | |||||
| testAgentInterface(t, agentClient, testPrivateKeys["rsa"], nil, 0) | |||||
| conn.Close() | |||||
| } | |||||
| @@ -1,64 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: | |||||
| // ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three | |||||
| // instances. | |||||
| package agent | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "fmt" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh/testdata" | |||||
| ) | |||||
| var ( | |||||
| testPrivateKeys map[string]interface{} | |||||
| testSigners map[string]ssh.Signer | |||||
| testPublicKeys map[string]ssh.PublicKey | |||||
| ) | |||||
| func init() { | |||||
| var err error | |||||
| n := len(testdata.PEMBytes) | |||||
| testPrivateKeys = make(map[string]interface{}, n) | |||||
| testSigners = make(map[string]ssh.Signer, n) | |||||
| testPublicKeys = make(map[string]ssh.PublicKey, n) | |||||
| for t, k := range testdata.PEMBytes { | |||||
| testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) | |||||
| } | |||||
| testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) | |||||
| } | |||||
| testPublicKeys[t] = testSigners[t].PublicKey() | |||||
| } | |||||
| // Create a cert and sign it for use in tests. | |||||
| testCert := &ssh.Certificate{ | |||||
| Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage | |||||
| ValidAfter: 0, // unix epoch | |||||
| ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. | |||||
| Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| Key: testPublicKeys["ecdsa"], | |||||
| SignatureKey: testPublicKeys["rsa"], | |||||
| Permissions: ssh.Permissions{ | |||||
| CriticalOptions: map[string]string{}, | |||||
| Extensions: map[string]string{}, | |||||
| }, | |||||
| } | |||||
| testCert.SignCert(rand.Reader, testSigners["rsa"]) | |||||
| testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] | |||||
| testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) | |||||
| } | |||||
| } | |||||
| @@ -1,122 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "errors" | |||||
| "io" | |||||
| "net" | |||||
| "testing" | |||||
| ) | |||||
| type server struct { | |||||
| *ServerConn | |||||
| chans <-chan NewChannel | |||||
| } | |||||
| func newServer(c net.Conn, conf *ServerConfig) (*server, error) { | |||||
| sconn, chans, reqs, err := NewServerConn(c, conf) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| go DiscardRequests(reqs) | |||||
| return &server{sconn, chans}, nil | |||||
| } | |||||
| func (s *server) Accept() (NewChannel, error) { | |||||
| n, ok := <-s.chans | |||||
| if !ok { | |||||
| return nil, io.EOF | |||||
| } | |||||
| return n, nil | |||||
| } | |||||
| func sshPipe() (Conn, *server, error) { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| clientConf := ClientConfig{ | |||||
| User: "user", | |||||
| } | |||||
| serverConf := ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| serverConf.AddHostKey(testSigners["ecdsa"]) | |||||
| done := make(chan *server, 1) | |||||
| go func() { | |||||
| server, err := newServer(c2, &serverConf) | |||||
| if err != nil { | |||||
| done <- nil | |||||
| } | |||||
| done <- server | |||||
| }() | |||||
| client, _, reqs, err := NewClientConn(c1, "", &clientConf) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| server := <-done | |||||
| if server == nil { | |||||
| return nil, nil, errors.New("server handshake failed.") | |||||
| } | |||||
| go DiscardRequests(reqs) | |||||
| return client, server, nil | |||||
| } | |||||
| func BenchmarkEndToEnd(b *testing.B) { | |||||
| b.StopTimer() | |||||
| client, server, err := sshPipe() | |||||
| if err != nil { | |||||
| b.Fatalf("sshPipe: %v", err) | |||||
| } | |||||
| defer client.Close() | |||||
| defer server.Close() | |||||
| size := (1 << 20) | |||||
| input := make([]byte, size) | |||||
| output := make([]byte, size) | |||||
| b.SetBytes(int64(size)) | |||||
| done := make(chan int, 1) | |||||
| go func() { | |||||
| newCh, err := server.Accept() | |||||
| if err != nil { | |||||
| b.Fatalf("Client: %v", err) | |||||
| } | |||||
| ch, incoming, err := newCh.Accept() | |||||
| go DiscardRequests(incoming) | |||||
| for i := 0; i < b.N; i++ { | |||||
| if _, err := io.ReadFull(ch, output); err != nil { | |||||
| b.Fatalf("ReadFull: %v", err) | |||||
| } | |||||
| } | |||||
| ch.Close() | |||||
| done <- 1 | |||||
| }() | |||||
| ch, in, err := client.OpenChannel("speed", nil) | |||||
| if err != nil { | |||||
| b.Fatalf("OpenChannel: %v", err) | |||||
| } | |||||
| go DiscardRequests(in) | |||||
| b.ResetTimer() | |||||
| b.StartTimer() | |||||
| for i := 0; i < b.N; i++ { | |||||
| if _, err := ch.Write(input); err != nil { | |||||
| b.Fatalf("WriteFull: %v", err) | |||||
| } | |||||
| } | |||||
| ch.Close() | |||||
| b.StopTimer() | |||||
| <-done | |||||
| } | |||||
| @@ -1,98 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "io" | |||||
| "sync" | |||||
| ) | |||||
| // buffer provides a linked list buffer for data exchange | |||||
| // between producer and consumer. Theoretically the buffer is | |||||
| // of unlimited capacity as it does no allocation of its own. | |||||
| type buffer struct { | |||||
| // protects concurrent access to head, tail and closed | |||||
| *sync.Cond | |||||
| head *element // the buffer that will be read first | |||||
| tail *element // the buffer that will be read last | |||||
| closed bool | |||||
| } | |||||
| // An element represents a single link in a linked list. | |||||
| type element struct { | |||||
| buf []byte | |||||
| next *element | |||||
| } | |||||
| // newBuffer returns an empty buffer that is not closed. | |||||
| func newBuffer() *buffer { | |||||
| e := new(element) | |||||
| b := &buffer{ | |||||
| Cond: newCond(), | |||||
| head: e, | |||||
| tail: e, | |||||
| } | |||||
| return b | |||||
| } | |||||
| // write makes buf available for Read to receive. | |||||
| // buf must not be modified after the call to write. | |||||
| func (b *buffer) write(buf []byte) { | |||||
| b.Cond.L.Lock() | |||||
| e := &element{buf: buf} | |||||
| b.tail.next = e | |||||
| b.tail = e | |||||
| b.Cond.Signal() | |||||
| b.Cond.L.Unlock() | |||||
| } | |||||
| // eof closes the buffer. Reads from the buffer once all | |||||
| // the data has been consumed will receive os.EOF. | |||||
| func (b *buffer) eof() error { | |||||
| b.Cond.L.Lock() | |||||
| b.closed = true | |||||
| b.Cond.Signal() | |||||
| b.Cond.L.Unlock() | |||||
| return nil | |||||
| } | |||||
| // Read reads data from the internal buffer in buf. Reads will block | |||||
| // if no data is available, or until the buffer is closed. | |||||
| func (b *buffer) Read(buf []byte) (n int, err error) { | |||||
| b.Cond.L.Lock() | |||||
| defer b.Cond.L.Unlock() | |||||
| for len(buf) > 0 { | |||||
| // if there is data in b.head, copy it | |||||
| if len(b.head.buf) > 0 { | |||||
| r := copy(buf, b.head.buf) | |||||
| buf, b.head.buf = buf[r:], b.head.buf[r:] | |||||
| n += r | |||||
| continue | |||||
| } | |||||
| // if there is a next buffer, make it the head | |||||
| if len(b.head.buf) == 0 && b.head != b.tail { | |||||
| b.head = b.head.next | |||||
| continue | |||||
| } | |||||
| // if at least one byte has been copied, return | |||||
| if n > 0 { | |||||
| break | |||||
| } | |||||
| // if nothing was read, and there is nothing outstanding | |||||
| // check to see if the buffer is closed. | |||||
| if b.closed { | |||||
| err = io.EOF | |||||
| break | |||||
| } | |||||
| // out of buffers, wait for producer | |||||
| b.Cond.Wait() | |||||
| } | |||||
| return | |||||
| } | |||||
| @@ -1,87 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "io" | |||||
| "testing" | |||||
| ) | |||||
| var alphabet = []byte("abcdefghijklmnopqrstuvwxyz") | |||||
| func TestBufferReadwrite(t *testing.T) { | |||||
| b := newBuffer() | |||||
| b.write(alphabet[:10]) | |||||
| r, _ := b.Read(make([]byte, 10)) | |||||
| if r != 10 { | |||||
| t.Fatalf("Expected written == read == 10, written: 10, read %d", r) | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(alphabet[:5]) | |||||
| r, _ = b.Read(make([]byte, 10)) | |||||
| if r != 5 { | |||||
| t.Fatalf("Expected written == read == 5, written: 5, read %d", r) | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(alphabet[:10]) | |||||
| r, _ = b.Read(make([]byte, 5)) | |||||
| if r != 5 { | |||||
| t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r) | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(alphabet[:5]) | |||||
| b.write(alphabet[5:15]) | |||||
| r, _ = b.Read(make([]byte, 10)) | |||||
| r2, _ := b.Read(make([]byte, 10)) | |||||
| if r != 10 || r2 != 5 || 15 != r+r2 { | |||||
| t.Fatal("Expected written == read == 15") | |||||
| } | |||||
| } | |||||
| func TestBufferClose(t *testing.T) { | |||||
| b := newBuffer() | |||||
| b.write(alphabet[:10]) | |||||
| b.eof() | |||||
| _, err := b.Read(make([]byte, 5)) | |||||
| if err != nil { | |||||
| t.Fatal("expected read of 5 to not return EOF") | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(alphabet[:10]) | |||||
| b.eof() | |||||
| r, err := b.Read(make([]byte, 5)) | |||||
| r2, err2 := b.Read(make([]byte, 10)) | |||||
| if r != 5 || r2 != 5 || err != nil || err2 != nil { | |||||
| t.Fatal("expected reads of 5 and 5") | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(alphabet[:10]) | |||||
| b.eof() | |||||
| r, err = b.Read(make([]byte, 5)) | |||||
| r2, err2 = b.Read(make([]byte, 10)) | |||||
| r3, err3 := b.Read(make([]byte, 10)) | |||||
| if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF { | |||||
| t.Fatal("expected reads of 5 and 5 and 0, with EOF") | |||||
| } | |||||
| b = newBuffer() | |||||
| b.write(make([]byte, 5)) | |||||
| b.write(make([]byte, 10)) | |||||
| b.eof() | |||||
| r, err = b.Read(make([]byte, 9)) | |||||
| r2, err2 = b.Read(make([]byte, 3)) | |||||
| r3, err3 = b.Read(make([]byte, 3)) | |||||
| r4, err4 := b.Read(make([]byte, 10)) | |||||
| if err != nil || err2 != nil || err3 != nil || err4 != io.EOF { | |||||
| t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4) | |||||
| } | |||||
| if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 { | |||||
| t.Fatal("Expected written == read == 15", r, r2, r3, r4) | |||||
| } | |||||
| } | |||||
| @@ -1,501 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "net" | |||||
| "sort" | |||||
| "time" | |||||
| ) | |||||
| // These constants from [PROTOCOL.certkeys] represent the algorithm names | |||||
| // for certificate types supported by this package. | |||||
| const ( | |||||
| CertAlgoRSAv01 = "ssh-rsa-cert-v01@openssh.com" | |||||
| CertAlgoDSAv01 = "ssh-dss-cert-v01@openssh.com" | |||||
| CertAlgoECDSA256v01 = "ecdsa-sha2-nistp256-cert-v01@openssh.com" | |||||
| CertAlgoECDSA384v01 = "ecdsa-sha2-nistp384-cert-v01@openssh.com" | |||||
| CertAlgoECDSA521v01 = "ecdsa-sha2-nistp521-cert-v01@openssh.com" | |||||
| ) | |||||
| // Certificate types distinguish between host and user | |||||
| // certificates. The values can be set in the CertType field of | |||||
| // Certificate. | |||||
| const ( | |||||
| UserCert = 1 | |||||
| HostCert = 2 | |||||
| ) | |||||
| // Signature represents a cryptographic signature. | |||||
| type Signature struct { | |||||
| Format string | |||||
| Blob []byte | |||||
| } | |||||
| // CertTimeInfinity can be used for OpenSSHCertV01.ValidBefore to indicate that | |||||
| // a certificate does not expire. | |||||
| const CertTimeInfinity = 1<<64 - 1 | |||||
| // An Certificate represents an OpenSSH certificate as defined in | |||||
| // [PROTOCOL.certkeys]?rev=1.8. | |||||
| type Certificate struct { | |||||
| Nonce []byte | |||||
| Key PublicKey | |||||
| Serial uint64 | |||||
| CertType uint32 | |||||
| KeyId string | |||||
| ValidPrincipals []string | |||||
| ValidAfter uint64 | |||||
| ValidBefore uint64 | |||||
| Permissions | |||||
| Reserved []byte | |||||
| SignatureKey PublicKey | |||||
| Signature *Signature | |||||
| } | |||||
| // genericCertData holds the key-independent part of the certificate data. | |||||
| // Overall, certificates contain an nonce, public key fields and | |||||
| // key-independent fields. | |||||
| type genericCertData struct { | |||||
| Serial uint64 | |||||
| CertType uint32 | |||||
| KeyId string | |||||
| ValidPrincipals []byte | |||||
| ValidAfter uint64 | |||||
| ValidBefore uint64 | |||||
| CriticalOptions []byte | |||||
| Extensions []byte | |||||
| Reserved []byte | |||||
| SignatureKey []byte | |||||
| Signature []byte | |||||
| } | |||||
| func marshalStringList(namelist []string) []byte { | |||||
| var to []byte | |||||
| for _, name := range namelist { | |||||
| s := struct{ N string }{name} | |||||
| to = append(to, Marshal(&s)...) | |||||
| } | |||||
| return to | |||||
| } | |||||
| type optionsTuple struct { | |||||
| Key string | |||||
| Value []byte | |||||
| } | |||||
| type optionsTupleValue struct { | |||||
| Value string | |||||
| } | |||||
| // serialize a map of critical options or extensions | |||||
| // issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, | |||||
| // we need two length prefixes for a non-empty string value | |||||
| func marshalTuples(tups map[string]string) []byte { | |||||
| keys := make([]string, 0, len(tups)) | |||||
| for key := range tups { | |||||
| keys = append(keys, key) | |||||
| } | |||||
| sort.Strings(keys) | |||||
| var ret []byte | |||||
| for _, key := range keys { | |||||
| s := optionsTuple{Key: key} | |||||
| if value := tups[key]; len(value) > 0 { | |||||
| s.Value = Marshal(&optionsTupleValue{value}) | |||||
| } | |||||
| ret = append(ret, Marshal(&s)...) | |||||
| } | |||||
| return ret | |||||
| } | |||||
| // issue #10569 - per [PROTOCOL.certkeys] and SSH implementation, | |||||
| // we need two length prefixes for a non-empty option value | |||||
| func parseTuples(in []byte) (map[string]string, error) { | |||||
| tups := map[string]string{} | |||||
| var lastKey string | |||||
| var haveLastKey bool | |||||
| for len(in) > 0 { | |||||
| var key, val, extra []byte | |||||
| var ok bool | |||||
| if key, in, ok = parseString(in); !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| keyStr := string(key) | |||||
| // according to [PROTOCOL.certkeys], the names must be in | |||||
| // lexical order. | |||||
| if haveLastKey && keyStr <= lastKey { | |||||
| return nil, fmt.Errorf("ssh: certificate options are not in lexical order") | |||||
| } | |||||
| lastKey, haveLastKey = keyStr, true | |||||
| // the next field is a data field, which if non-empty has a string embedded | |||||
| if val, in, ok = parseString(in); !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| if len(val) > 0 { | |||||
| val, extra, ok = parseString(val) | |||||
| if !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| if len(extra) > 0 { | |||||
| return nil, fmt.Errorf("ssh: unexpected trailing data after certificate option value") | |||||
| } | |||||
| tups[keyStr] = string(val) | |||||
| } else { | |||||
| tups[keyStr] = "" | |||||
| } | |||||
| } | |||||
| return tups, nil | |||||
| } | |||||
| func parseCert(in []byte, privAlgo string) (*Certificate, error) { | |||||
| nonce, rest, ok := parseString(in) | |||||
| if !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| key, rest, err := parsePubKey(rest, privAlgo) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var g genericCertData | |||||
| if err := Unmarshal(rest, &g); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c := &Certificate{ | |||||
| Nonce: nonce, | |||||
| Key: key, | |||||
| Serial: g.Serial, | |||||
| CertType: g.CertType, | |||||
| KeyId: g.KeyId, | |||||
| ValidAfter: g.ValidAfter, | |||||
| ValidBefore: g.ValidBefore, | |||||
| } | |||||
| for principals := g.ValidPrincipals; len(principals) > 0; { | |||||
| principal, rest, ok := parseString(principals) | |||||
| if !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| c.ValidPrincipals = append(c.ValidPrincipals, string(principal)) | |||||
| principals = rest | |||||
| } | |||||
| c.CriticalOptions, err = parseTuples(g.CriticalOptions) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c.Extensions, err = parseTuples(g.Extensions) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c.Reserved = g.Reserved | |||||
| k, err := ParsePublicKey(g.SignatureKey) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c.SignatureKey = k | |||||
| c.Signature, rest, ok = parseSignatureBody(g.Signature) | |||||
| if !ok || len(rest) > 0 { | |||||
| return nil, errors.New("ssh: signature parse error") | |||||
| } | |||||
| return c, nil | |||||
| } | |||||
| type openSSHCertSigner struct { | |||||
| pub *Certificate | |||||
| signer Signer | |||||
| } | |||||
| // NewCertSigner returns a Signer that signs with the given Certificate, whose | |||||
| // private key is held by signer. It returns an error if the public key in cert | |||||
| // doesn't match the key used by signer. | |||||
| func NewCertSigner(cert *Certificate, signer Signer) (Signer, error) { | |||||
| if bytes.Compare(cert.Key.Marshal(), signer.PublicKey().Marshal()) != 0 { | |||||
| return nil, errors.New("ssh: signer and cert have different public key") | |||||
| } | |||||
| return &openSSHCertSigner{cert, signer}, nil | |||||
| } | |||||
| func (s *openSSHCertSigner) Sign(rand io.Reader, data []byte) (*Signature, error) { | |||||
| return s.signer.Sign(rand, data) | |||||
| } | |||||
| func (s *openSSHCertSigner) PublicKey() PublicKey { | |||||
| return s.pub | |||||
| } | |||||
| const sourceAddressCriticalOption = "source-address" | |||||
| // CertChecker does the work of verifying a certificate. Its methods | |||||
| // can be plugged into ClientConfig.HostKeyCallback and | |||||
| // ServerConfig.PublicKeyCallback. For the CertChecker to work, | |||||
| // minimally, the IsAuthority callback should be set. | |||||
| type CertChecker struct { | |||||
| // SupportedCriticalOptions lists the CriticalOptions that the | |||||
| // server application layer understands. These are only used | |||||
| // for user certificates. | |||||
| SupportedCriticalOptions []string | |||||
| // IsAuthority should return true if the key is recognized as | |||||
| // an authority. This allows for certificates to be signed by other | |||||
| // certificates. | |||||
| IsAuthority func(auth PublicKey) bool | |||||
| // Clock is used for verifying time stamps. If nil, time.Now | |||||
| // is used. | |||||
| Clock func() time.Time | |||||
| // UserKeyFallback is called when CertChecker.Authenticate encounters a | |||||
| // public key that is not a certificate. It must implement validation | |||||
| // of user keys or else, if nil, all such keys are rejected. | |||||
| UserKeyFallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) | |||||
| // HostKeyFallback is called when CertChecker.CheckHostKey encounters a | |||||
| // public key that is not a certificate. It must implement host key | |||||
| // validation or else, if nil, all such keys are rejected. | |||||
| HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error | |||||
| // IsRevoked is called for each certificate so that revocation checking | |||||
| // can be implemented. It should return true if the given certificate | |||||
| // is revoked and false otherwise. If nil, no certificates are | |||||
| // considered to have been revoked. | |||||
| IsRevoked func(cert *Certificate) bool | |||||
| } | |||||
| // CheckHostKey checks a host key certificate. This method can be | |||||
| // plugged into ClientConfig.HostKeyCallback. | |||||
| func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) error { | |||||
| cert, ok := key.(*Certificate) | |||||
| if !ok { | |||||
| if c.HostKeyFallback != nil { | |||||
| return c.HostKeyFallback(addr, remote, key) | |||||
| } | |||||
| return errors.New("ssh: non-certificate host key") | |||||
| } | |||||
| if cert.CertType != HostCert { | |||||
| return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) | |||||
| } | |||||
| return c.CheckCert(addr, cert) | |||||
| } | |||||
| // Authenticate checks a user certificate. Authenticate can be used as | |||||
| // a value for ServerConfig.PublicKeyCallback. | |||||
| func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permissions, error) { | |||||
| cert, ok := pubKey.(*Certificate) | |||||
| if !ok { | |||||
| if c.UserKeyFallback != nil { | |||||
| return c.UserKeyFallback(conn, pubKey) | |||||
| } | |||||
| return nil, errors.New("ssh: normal key pairs not accepted") | |||||
| } | |||||
| if cert.CertType != UserCert { | |||||
| return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) | |||||
| } | |||||
| if err := c.CheckCert(conn.User(), cert); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &cert.Permissions, nil | |||||
| } | |||||
| // CheckCert checks CriticalOptions, ValidPrincipals, revocation, timestamp and | |||||
| // the signature of the certificate. | |||||
| func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { | |||||
| if c.IsRevoked != nil && c.IsRevoked(cert) { | |||||
| return fmt.Errorf("ssh: certicate serial %d revoked", cert.Serial) | |||||
| } | |||||
| for opt, _ := range cert.CriticalOptions { | |||||
| // sourceAddressCriticalOption will be enforced by | |||||
| // serverAuthenticate | |||||
| if opt == sourceAddressCriticalOption { | |||||
| continue | |||||
| } | |||||
| found := false | |||||
| for _, supp := range c.SupportedCriticalOptions { | |||||
| if supp == opt { | |||||
| found = true | |||||
| break | |||||
| } | |||||
| } | |||||
| if !found { | |||||
| return fmt.Errorf("ssh: unsupported critical option %q in certificate", opt) | |||||
| } | |||||
| } | |||||
| if len(cert.ValidPrincipals) > 0 { | |||||
| // By default, certs are valid for all users/hosts. | |||||
| found := false | |||||
| for _, p := range cert.ValidPrincipals { | |||||
| if p == principal { | |||||
| found = true | |||||
| break | |||||
| } | |||||
| } | |||||
| if !found { | |||||
| return fmt.Errorf("ssh: principal %q not in the set of valid principals for given certificate: %q", principal, cert.ValidPrincipals) | |||||
| } | |||||
| } | |||||
| if !c.IsAuthority(cert.SignatureKey) { | |||||
| return fmt.Errorf("ssh: certificate signed by unrecognized authority") | |||||
| } | |||||
| clock := c.Clock | |||||
| if clock == nil { | |||||
| clock = time.Now | |||||
| } | |||||
| unixNow := clock().Unix() | |||||
| if after := int64(cert.ValidAfter); after < 0 || unixNow < int64(cert.ValidAfter) { | |||||
| return fmt.Errorf("ssh: cert is not yet valid") | |||||
| } | |||||
| if before := int64(cert.ValidBefore); cert.ValidBefore != uint64(CertTimeInfinity) && (unixNow >= before || before < 0) { | |||||
| return fmt.Errorf("ssh: cert has expired") | |||||
| } | |||||
| if err := cert.SignatureKey.Verify(cert.bytesForSigning(), cert.Signature); err != nil { | |||||
| return fmt.Errorf("ssh: certificate signature does not verify") | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // SignCert sets c.SignatureKey to the authority's public key and stores a | |||||
| // Signature, by authority, in the certificate. | |||||
| func (c *Certificate) SignCert(rand io.Reader, authority Signer) error { | |||||
| c.Nonce = make([]byte, 32) | |||||
| if _, err := io.ReadFull(rand, c.Nonce); err != nil { | |||||
| return err | |||||
| } | |||||
| c.SignatureKey = authority.PublicKey() | |||||
| sig, err := authority.Sign(rand, c.bytesForSigning()) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| c.Signature = sig | |||||
| return nil | |||||
| } | |||||
| var certAlgoNames = map[string]string{ | |||||
| KeyAlgoRSA: CertAlgoRSAv01, | |||||
| KeyAlgoDSA: CertAlgoDSAv01, | |||||
| KeyAlgoECDSA256: CertAlgoECDSA256v01, | |||||
| KeyAlgoECDSA384: CertAlgoECDSA384v01, | |||||
| KeyAlgoECDSA521: CertAlgoECDSA521v01, | |||||
| } | |||||
| // certToPrivAlgo returns the underlying algorithm for a certificate algorithm. | |||||
| // Panics if a non-certificate algorithm is passed. | |||||
| func certToPrivAlgo(algo string) string { | |||||
| for privAlgo, pubAlgo := range certAlgoNames { | |||||
| if pubAlgo == algo { | |||||
| return privAlgo | |||||
| } | |||||
| } | |||||
| panic("unknown cert algorithm") | |||||
| } | |||||
| func (cert *Certificate) bytesForSigning() []byte { | |||||
| c2 := *cert | |||||
| c2.Signature = nil | |||||
| out := c2.Marshal() | |||||
| // Drop trailing signature length. | |||||
| return out[:len(out)-4] | |||||
| } | |||||
| // Marshal serializes c into OpenSSH's wire format. It is part of the | |||||
| // PublicKey interface. | |||||
| func (c *Certificate) Marshal() []byte { | |||||
| generic := genericCertData{ | |||||
| Serial: c.Serial, | |||||
| CertType: c.CertType, | |||||
| KeyId: c.KeyId, | |||||
| ValidPrincipals: marshalStringList(c.ValidPrincipals), | |||||
| ValidAfter: uint64(c.ValidAfter), | |||||
| ValidBefore: uint64(c.ValidBefore), | |||||
| CriticalOptions: marshalTuples(c.CriticalOptions), | |||||
| Extensions: marshalTuples(c.Extensions), | |||||
| Reserved: c.Reserved, | |||||
| SignatureKey: c.SignatureKey.Marshal(), | |||||
| } | |||||
| if c.Signature != nil { | |||||
| generic.Signature = Marshal(c.Signature) | |||||
| } | |||||
| genericBytes := Marshal(&generic) | |||||
| keyBytes := c.Key.Marshal() | |||||
| _, keyBytes, _ = parseString(keyBytes) | |||||
| prefix := Marshal(&struct { | |||||
| Name string | |||||
| Nonce []byte | |||||
| Key []byte `ssh:"rest"` | |||||
| }{c.Type(), c.Nonce, keyBytes}) | |||||
| result := make([]byte, 0, len(prefix)+len(genericBytes)) | |||||
| result = append(result, prefix...) | |||||
| result = append(result, genericBytes...) | |||||
| return result | |||||
| } | |||||
| // Type returns the key name. It is part of the PublicKey interface. | |||||
| func (c *Certificate) Type() string { | |||||
| algo, ok := certAlgoNames[c.Key.Type()] | |||||
| if !ok { | |||||
| panic("unknown cert key type") | |||||
| } | |||||
| return algo | |||||
| } | |||||
| // Verify verifies a signature against the certificate's public | |||||
| // key. It is part of the PublicKey interface. | |||||
| func (c *Certificate) Verify(data []byte, sig *Signature) error { | |||||
| return c.Key.Verify(data, sig) | |||||
| } | |||||
| func parseSignatureBody(in []byte) (out *Signature, rest []byte, ok bool) { | |||||
| format, in, ok := parseString(in) | |||||
| if !ok { | |||||
| return | |||||
| } | |||||
| out = &Signature{ | |||||
| Format: string(format), | |||||
| } | |||||
| if out.Blob, in, ok = parseString(in); !ok { | |||||
| return | |||||
| } | |||||
| return out, in, ok | |||||
| } | |||||
| func parseSignature(in []byte) (out *Signature, rest []byte, ok bool) { | |||||
| sigBytes, rest, ok := parseString(in) | |||||
| if !ok { | |||||
| return | |||||
| } | |||||
| out, trailing, ok := parseSignatureBody(sigBytes) | |||||
| if !ok || len(trailing) > 0 { | |||||
| return nil, nil, false | |||||
| } | |||||
| return | |||||
| } | |||||
| @@ -1,216 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "reflect" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| // Cert generated by ssh-keygen 6.0p1 Debian-4. | |||||
| // % ssh-keygen -s ca-key -I test user-key | |||||
| const exampleSSHCert = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgb1srW/W3ZDjYAO45xLYAwzHBDLsJ4Ux6ICFIkTjb1LEAAAADAQABAAAAYQCkoR51poH0wE8w72cqSB8Sszx+vAhzcMdCO0wqHTj7UNENHWEXGrU0E0UQekD7U+yhkhtoyjbPOVIP7hNa6aRk/ezdh/iUnCIt4Jt1v3Z1h1P+hA4QuYFMHNB+rmjPwAcAAAAAAAAAAAAAAAEAAAAEdGVzdAAAAAAAAAAAAAAAAP//////////AAAAAAAAAIIAAAAVcGVybWl0LVgxMS1mb3J3YXJkaW5nAAAAAAAAABdwZXJtaXQtYWdlbnQtZm9yd2FyZGluZwAAAAAAAAAWcGVybWl0LXBvcnQtZm9yd2FyZGluZwAAAAAAAAAKcGVybWl0LXB0eQAAAAAAAAAOcGVybWl0LXVzZXItcmMAAAAAAAAAAAAAAHcAAAAHc3NoLXJzYQAAAAMBAAEAAABhANFS2kaktpSGc+CcmEKPyw9mJC4nZKxHKTgLVZeaGbFZOvJTNzBspQHdy7Q1uKSfktxpgjZnksiu/tFF9ngyY2KFoc+U88ya95IZUycBGCUbBQ8+bhDtw/icdDGQD5WnUwAAAG8AAAAHc3NoLXJzYQAAAGC8Y9Z2LQKhIhxf52773XaWrXdxP0t3GBVo4A10vUWiYoAGepr6rQIoGGXFxT4B9Gp+nEBJjOwKDXPrAevow0T9ca8gZN+0ykbhSrXLE5Ao48rqr3zP4O1/9P7e6gp0gw8=` | |||||
| func TestParseCert(t *testing.T) { | |||||
| authKeyBytes := []byte(exampleSSHCert) | |||||
| key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) | |||||
| if err != nil { | |||||
| t.Fatalf("ParseAuthorizedKey: %v", err) | |||||
| } | |||||
| if len(rest) > 0 { | |||||
| t.Errorf("rest: got %q, want empty", rest) | |||||
| } | |||||
| if _, ok := key.(*Certificate); !ok { | |||||
| t.Fatalf("got %v (%T), want *Certificate", key, key) | |||||
| } | |||||
| marshaled := MarshalAuthorizedKey(key) | |||||
| // Before comparison, remove the trailing newline that | |||||
| // MarshalAuthorizedKey adds. | |||||
| marshaled = marshaled[:len(marshaled)-1] | |||||
| if !bytes.Equal(authKeyBytes, marshaled) { | |||||
| t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) | |||||
| } | |||||
| } | |||||
| // Cert generated by ssh-keygen OpenSSH_6.8p1 OS X 10.10.3 | |||||
| // % ssh-keygen -s ca -I testcert -O source-address=192.168.1.0/24 -O force-command=/bin/sleep user.pub | |||||
| // user.pub key: ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMN | |||||
| // Critical Options: | |||||
| // force-command /bin/sleep | |||||
| // source-address 192.168.1.0/24 | |||||
| // Extensions: | |||||
| // permit-X11-forwarding | |||||
| // permit-agent-forwarding | |||||
| // permit-port-forwarding | |||||
| // permit-pty | |||||
| // permit-user-rc | |||||
| const exampleSSHCertWithOptions = `ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgDyysCJY0XrO1n03EeRRoITnTPdjENFmWDs9X58PP3VUAAAADAQABAAABAQDACh1rt2DXfV3hk6fszSQcQ/rueMId0kVD9U7nl8cfEnFxqOCrNT92g4laQIGl2mn8lsGZfTLg8ksHq3gkvgO3oo/0wHy4v32JeBOHTsN5AL4gfHNEhWeWb50ev47hnTsRIt9P4dxogeUo/hTu7j9+s9lLpEQXCvq6xocXQt0j8MV9qZBBXFLXVT3cWIkSqOdwt/5ZBg+1GSrc7WfCXVWgTk4a20uPMuJPxU4RQwZW6X3+O8Pqo8C3cW0OzZRFP6gUYUKUsTI5WntlS+LAxgw1mZNsozFGdbiOPRnEryE3SRldh9vjDR3tin1fGpA5P7+CEB/bqaXtG3V+F2OkqaMNAAAAAAAAAAAAAAABAAAACHRlc3RjZXJ0AAAAAAAAAAAAAAAA//////////8AAABLAAAADWZvcmNlLWNvbW1hbmQAAAAOAAAACi9iaW4vc2xlZXAAAAAOc291cmNlLWFkZHJlc3MAAAASAAAADjE5Mi4xNjguMS4wLzI0AAAAggAAABVwZXJtaXQtWDExLWZvcndhcmRpbmcAAAAAAAAAF3Blcm1pdC1hZ2VudC1mb3J3YXJkaW5nAAAAAAAAABZwZXJtaXQtcG9ydC1mb3J3YXJkaW5nAAAAAAAAAApwZXJtaXQtcHR5AAAAAAAAAA5wZXJtaXQtdXNlci1yYwAAAAAAAAAAAAABFwAAAAdzc2gtcnNhAAAAAwEAAQAAAQEAwU+c5ui5A8+J/CFpjW8wCa52bEODA808WWQDCSuTG/eMXNf59v9Y8Pk0F1E9dGCosSNyVcB/hacUrc6He+i97+HJCyKavBsE6GDxrjRyxYqAlfcOXi/IVmaUGiO8OQ39d4GHrjToInKvExSUeleQyH4Y4/e27T/pILAqPFL3fyrvMLT5qU9QyIt6zIpa7GBP5+urouNavMprV3zsfIqNBbWypinOQAw823a5wN+zwXnhZrgQiHZ/USG09Y6k98y1dTVz8YHlQVR4D3lpTAsKDKJ5hCH9WU4fdf+lU8OyNGaJ/vz0XNqxcToe1l4numLTnaoSuH89pHryjqurB7lJKwAAAQ8AAAAHc3NoLXJzYQAAAQCaHvUIoPL1zWUHIXLvu96/HU1s/i4CAW2IIEuGgxCUCiFj6vyTyYtgxQxcmbfZf6eaITlS6XJZa7Qq4iaFZh75C1DXTX8labXhRSD4E2t//AIP9MC1rtQC5xo6FmbQ+BoKcDskr+mNACcbRSxs3IL3bwCfWDnIw2WbVox9ZdcthJKk4UoCW4ix4QwdHw7zlddlz++fGEEVhmTbll1SUkycGApPFBsAYRTMupUJcYPIeReBI/m8XfkoMk99bV8ZJQTAd7OekHY2/48Ff53jLmyDjP7kNw1F8OaPtkFs6dGJXta4krmaekPy87j+35In5hFj7yoOqvSbmYUkeX70/GGQ` | |||||
| func TestParseCertWithOptions(t *testing.T) { | |||||
| opts := map[string]string{ | |||||
| "source-address": "192.168.1.0/24", | |||||
| "force-command": "/bin/sleep", | |||||
| } | |||||
| exts := map[string]string{ | |||||
| "permit-X11-forwarding": "", | |||||
| "permit-agent-forwarding": "", | |||||
| "permit-port-forwarding": "", | |||||
| "permit-pty": "", | |||||
| "permit-user-rc": "", | |||||
| } | |||||
| authKeyBytes := []byte(exampleSSHCertWithOptions) | |||||
| key, _, _, rest, err := ParseAuthorizedKey(authKeyBytes) | |||||
| if err != nil { | |||||
| t.Fatalf("ParseAuthorizedKey: %v", err) | |||||
| } | |||||
| if len(rest) > 0 { | |||||
| t.Errorf("rest: got %q, want empty", rest) | |||||
| } | |||||
| cert, ok := key.(*Certificate) | |||||
| if !ok { | |||||
| t.Fatalf("got %v (%T), want *Certificate", key, key) | |||||
| } | |||||
| if !reflect.DeepEqual(cert.CriticalOptions, opts) { | |||||
| t.Errorf("unexpected critical options - got %v, want %v", cert.CriticalOptions, opts) | |||||
| } | |||||
| if !reflect.DeepEqual(cert.Extensions, exts) { | |||||
| t.Errorf("unexpected Extensions - got %v, want %v", cert.Extensions, exts) | |||||
| } | |||||
| marshaled := MarshalAuthorizedKey(key) | |||||
| // Before comparison, remove the trailing newline that | |||||
| // MarshalAuthorizedKey adds. | |||||
| marshaled = marshaled[:len(marshaled)-1] | |||||
| if !bytes.Equal(authKeyBytes, marshaled) { | |||||
| t.Errorf("marshaled certificate does not match original: got %q, want %q", marshaled, authKeyBytes) | |||||
| } | |||||
| } | |||||
| func TestValidateCert(t *testing.T) { | |||||
| key, _, _, _, err := ParseAuthorizedKey([]byte(exampleSSHCert)) | |||||
| if err != nil { | |||||
| t.Fatalf("ParseAuthorizedKey: %v", err) | |||||
| } | |||||
| validCert, ok := key.(*Certificate) | |||||
| if !ok { | |||||
| t.Fatalf("got %v (%T), want *Certificate", key, key) | |||||
| } | |||||
| checker := CertChecker{} | |||||
| checker.IsAuthority = func(k PublicKey) bool { | |||||
| return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) | |||||
| } | |||||
| if err := checker.CheckCert("user", validCert); err != nil { | |||||
| t.Errorf("Unable to validate certificate: %v", err) | |||||
| } | |||||
| invalidCert := &Certificate{ | |||||
| Key: testPublicKeys["rsa"], | |||||
| SignatureKey: testPublicKeys["ecdsa"], | |||||
| ValidBefore: CertTimeInfinity, | |||||
| Signature: &Signature{}, | |||||
| } | |||||
| if err := checker.CheckCert("user", invalidCert); err == nil { | |||||
| t.Error("Invalid cert signature passed validation") | |||||
| } | |||||
| } | |||||
| func TestValidateCertTime(t *testing.T) { | |||||
| cert := Certificate{ | |||||
| ValidPrincipals: []string{"user"}, | |||||
| Key: testPublicKeys["rsa"], | |||||
| ValidAfter: 50, | |||||
| ValidBefore: 100, | |||||
| } | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| for ts, ok := range map[int64]bool{ | |||||
| 25: false, | |||||
| 50: true, | |||||
| 99: true, | |||||
| 100: false, | |||||
| 125: false, | |||||
| } { | |||||
| checker := CertChecker{ | |||||
| Clock: func() time.Time { return time.Unix(ts, 0) }, | |||||
| } | |||||
| checker.IsAuthority = func(k PublicKey) bool { | |||||
| return bytes.Equal(k.Marshal(), | |||||
| testPublicKeys["ecdsa"].Marshal()) | |||||
| } | |||||
| if v := checker.CheckCert("user", &cert); (v == nil) != ok { | |||||
| t.Errorf("Authenticate(%d): %v", ts, v) | |||||
| } | |||||
| } | |||||
| } | |||||
| // TODO(hanwen): tests for | |||||
| // | |||||
| // host keys: | |||||
| // * fallbacks | |||||
| func TestHostKeyCert(t *testing.T) { | |||||
| cert := &Certificate{ | |||||
| ValidPrincipals: []string{"hostname", "hostname.domain"}, | |||||
| Key: testPublicKeys["rsa"], | |||||
| ValidBefore: CertTimeInfinity, | |||||
| CertType: HostCert, | |||||
| } | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| checker := &CertChecker{ | |||||
| IsAuthority: func(p PublicKey) bool { | |||||
| return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) | |||||
| }, | |||||
| } | |||||
| certSigner, err := NewCertSigner(cert, testSigners["rsa"]) | |||||
| if err != nil { | |||||
| t.Errorf("NewCertSigner: %v", err) | |||||
| } | |||||
| for _, name := range []string{"hostname", "otherhost"} { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| errc := make(chan error) | |||||
| go func() { | |||||
| conf := ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| conf.AddHostKey(certSigner) | |||||
| _, _, _, err := NewServerConn(c1, &conf) | |||||
| errc <- err | |||||
| }() | |||||
| config := &ClientConfig{ | |||||
| User: "user", | |||||
| HostKeyCallback: checker.CheckHostKey, | |||||
| } | |||||
| _, _, _, err = NewClientConn(c2, name, config) | |||||
| succeed := name == "hostname" | |||||
| if (err == nil) != succeed { | |||||
| t.Fatalf("NewClientConn(%q): %v", name, err) | |||||
| } | |||||
| err = <-errc | |||||
| if (err == nil) != succeed { | |||||
| t.Fatalf("NewServerConn(%q): %v", name, err) | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,631 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "encoding/binary" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "log" | |||||
| "sync" | |||||
| ) | |||||
| const ( | |||||
| minPacketLength = 9 | |||||
| // channelMaxPacket contains the maximum number of bytes that will be | |||||
| // sent in a single packet. As per RFC 4253, section 6.1, 32k is also | |||||
| // the minimum. | |||||
| channelMaxPacket = 1 << 15 | |||||
| // We follow OpenSSH here. | |||||
| channelWindowSize = 64 * channelMaxPacket | |||||
| ) | |||||
| // NewChannel represents an incoming request to a channel. It must either be | |||||
| // accepted for use by calling Accept, or rejected by calling Reject. | |||||
| type NewChannel interface { | |||||
| // Accept accepts the channel creation request. It returns the Channel | |||||
| // and a Go channel containing SSH requests. The Go channel must be | |||||
| // serviced otherwise the Channel will hang. | |||||
| Accept() (Channel, <-chan *Request, error) | |||||
| // Reject rejects the channel creation request. After calling | |||||
| // this, no other methods on the Channel may be called. | |||||
| Reject(reason RejectionReason, message string) error | |||||
| // ChannelType returns the type of the channel, as supplied by the | |||||
| // client. | |||||
| ChannelType() string | |||||
| // ExtraData returns the arbitrary payload for this channel, as supplied | |||||
| // by the client. This data is specific to the channel type. | |||||
| ExtraData() []byte | |||||
| } | |||||
| // A Channel is an ordered, reliable, flow-controlled, duplex stream | |||||
| // that is multiplexed over an SSH connection. | |||||
| type Channel interface { | |||||
| // Read reads up to len(data) bytes from the channel. | |||||
| Read(data []byte) (int, error) | |||||
| // Write writes len(data) bytes to the channel. | |||||
| Write(data []byte) (int, error) | |||||
| // Close signals end of channel use. No data may be sent after this | |||||
| // call. | |||||
| Close() error | |||||
| // CloseWrite signals the end of sending in-band | |||||
| // data. Requests may still be sent, and the other side may | |||||
| // still send data | |||||
| CloseWrite() error | |||||
| // SendRequest sends a channel request. If wantReply is true, | |||||
| // it will wait for a reply and return the result as a | |||||
| // boolean, otherwise the return value will be false. Channel | |||||
| // requests are out-of-band messages so they may be sent even | |||||
| // if the data stream is closed or blocked by flow control. | |||||
| SendRequest(name string, wantReply bool, payload []byte) (bool, error) | |||||
| // Stderr returns an io.ReadWriter that writes to this channel | |||||
| // with the extended data type set to stderr. Stderr may | |||||
| // safely be read and written from a different goroutine than | |||||
| // Read and Write respectively. | |||||
| Stderr() io.ReadWriter | |||||
| } | |||||
| // Request is a request sent outside of the normal stream of | |||||
| // data. Requests can either be specific to an SSH channel, or they | |||||
| // can be global. | |||||
| type Request struct { | |||||
| Type string | |||||
| WantReply bool | |||||
| Payload []byte | |||||
| ch *channel | |||||
| mux *mux | |||||
| } | |||||
| // Reply sends a response to a request. It must be called for all requests | |||||
| // where WantReply is true and is a no-op otherwise. The payload argument is | |||||
| // ignored for replies to channel-specific requests. | |||||
| func (r *Request) Reply(ok bool, payload []byte) error { | |||||
| if !r.WantReply { | |||||
| return nil | |||||
| } | |||||
| if r.ch == nil { | |||||
| return r.mux.ackRequest(ok, payload) | |||||
| } | |||||
| return r.ch.ackRequest(ok) | |||||
| } | |||||
| // RejectionReason is an enumeration used when rejecting channel creation | |||||
| // requests. See RFC 4254, section 5.1. | |||||
| type RejectionReason uint32 | |||||
| const ( | |||||
| Prohibited RejectionReason = iota + 1 | |||||
| ConnectionFailed | |||||
| UnknownChannelType | |||||
| ResourceShortage | |||||
| ) | |||||
| // String converts the rejection reason to human readable form. | |||||
| func (r RejectionReason) String() string { | |||||
| switch r { | |||||
| case Prohibited: | |||||
| return "administratively prohibited" | |||||
| case ConnectionFailed: | |||||
| return "connect failed" | |||||
| case UnknownChannelType: | |||||
| return "unknown channel type" | |||||
| case ResourceShortage: | |||||
| return "resource shortage" | |||||
| } | |||||
| return fmt.Sprintf("unknown reason %d", int(r)) | |||||
| } | |||||
| func min(a uint32, b int) uint32 { | |||||
| if a < uint32(b) { | |||||
| return a | |||||
| } | |||||
| return uint32(b) | |||||
| } | |||||
| type channelDirection uint8 | |||||
| const ( | |||||
| channelInbound channelDirection = iota | |||||
| channelOutbound | |||||
| ) | |||||
| // channel is an implementation of the Channel interface that works | |||||
| // with the mux class. | |||||
| type channel struct { | |||||
| // R/O after creation | |||||
| chanType string | |||||
| extraData []byte | |||||
| localId, remoteId uint32 | |||||
| // maxIncomingPayload and maxRemotePayload are the maximum | |||||
| // payload sizes of normal and extended data packets for | |||||
| // receiving and sending, respectively. The wire packet will | |||||
| // be 9 or 13 bytes larger (excluding encryption overhead). | |||||
| maxIncomingPayload uint32 | |||||
| maxRemotePayload uint32 | |||||
| mux *mux | |||||
| // decided is set to true if an accept or reject message has been sent | |||||
| // (for outbound channels) or received (for inbound channels). | |||||
| decided bool | |||||
| // direction contains either channelOutbound, for channels created | |||||
| // locally, or channelInbound, for channels created by the peer. | |||||
| direction channelDirection | |||||
| // Pending internal channel messages. | |||||
| msg chan interface{} | |||||
| // Since requests have no ID, there can be only one request | |||||
| // with WantReply=true outstanding. This lock is held by a | |||||
| // goroutine that has such an outgoing request pending. | |||||
| sentRequestMu sync.Mutex | |||||
| incomingRequests chan *Request | |||||
| sentEOF bool | |||||
| // thread-safe data | |||||
| remoteWin window | |||||
| pending *buffer | |||||
| extPending *buffer | |||||
| // windowMu protects myWindow, the flow-control window. | |||||
| windowMu sync.Mutex | |||||
| myWindow uint32 | |||||
| // writeMu serializes calls to mux.conn.writePacket() and | |||||
| // protects sentClose and packetPool. This mutex must be | |||||
| // different from windowMu, as writePacket can block if there | |||||
| // is a key exchange pending. | |||||
| writeMu sync.Mutex | |||||
| sentClose bool | |||||
| // packetPool has a buffer for each extended channel ID to | |||||
| // save allocations during writes. | |||||
| packetPool map[uint32][]byte | |||||
| } | |||||
| // writePacket sends a packet. If the packet is a channel close, it updates | |||||
| // sentClose. This method takes the lock c.writeMu. | |||||
| func (c *channel) writePacket(packet []byte) error { | |||||
| c.writeMu.Lock() | |||||
| if c.sentClose { | |||||
| c.writeMu.Unlock() | |||||
| return io.EOF | |||||
| } | |||||
| c.sentClose = (packet[0] == msgChannelClose) | |||||
| err := c.mux.conn.writePacket(packet) | |||||
| c.writeMu.Unlock() | |||||
| return err | |||||
| } | |||||
| func (c *channel) sendMessage(msg interface{}) error { | |||||
| if debugMux { | |||||
| log.Printf("send %d: %#v", c.mux.chanList.offset, msg) | |||||
| } | |||||
| p := Marshal(msg) | |||||
| binary.BigEndian.PutUint32(p[1:], c.remoteId) | |||||
| return c.writePacket(p) | |||||
| } | |||||
| // WriteExtended writes data to a specific extended stream. These streams are | |||||
| // used, for example, for stderr. | |||||
| func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) { | |||||
| if c.sentEOF { | |||||
| return 0, io.EOF | |||||
| } | |||||
| // 1 byte message type, 4 bytes remoteId, 4 bytes data length | |||||
| opCode := byte(msgChannelData) | |||||
| headerLength := uint32(9) | |||||
| if extendedCode > 0 { | |||||
| headerLength += 4 | |||||
| opCode = msgChannelExtendedData | |||||
| } | |||||
| c.writeMu.Lock() | |||||
| packet := c.packetPool[extendedCode] | |||||
| // We don't remove the buffer from packetPool, so | |||||
| // WriteExtended calls from different goroutines will be | |||||
| // flagged as errors by the race detector. | |||||
| c.writeMu.Unlock() | |||||
| for len(data) > 0 { | |||||
| space := min(c.maxRemotePayload, len(data)) | |||||
| if space, err = c.remoteWin.reserve(space); err != nil { | |||||
| return n, err | |||||
| } | |||||
| if want := headerLength + space; uint32(cap(packet)) < want { | |||||
| packet = make([]byte, want) | |||||
| } else { | |||||
| packet = packet[:want] | |||||
| } | |||||
| todo := data[:space] | |||||
| packet[0] = opCode | |||||
| binary.BigEndian.PutUint32(packet[1:], c.remoteId) | |||||
| if extendedCode > 0 { | |||||
| binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode)) | |||||
| } | |||||
| binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo))) | |||||
| copy(packet[headerLength:], todo) | |||||
| if err = c.writePacket(packet); err != nil { | |||||
| return n, err | |||||
| } | |||||
| n += len(todo) | |||||
| data = data[len(todo):] | |||||
| } | |||||
| c.writeMu.Lock() | |||||
| c.packetPool[extendedCode] = packet | |||||
| c.writeMu.Unlock() | |||||
| return n, err | |||||
| } | |||||
| func (c *channel) handleData(packet []byte) error { | |||||
| headerLen := 9 | |||||
| isExtendedData := packet[0] == msgChannelExtendedData | |||||
| if isExtendedData { | |||||
| headerLen = 13 | |||||
| } | |||||
| if len(packet) < headerLen { | |||||
| // malformed data packet | |||||
| return parseError(packet[0]) | |||||
| } | |||||
| var extended uint32 | |||||
| if isExtendedData { | |||||
| extended = binary.BigEndian.Uint32(packet[5:]) | |||||
| } | |||||
| length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen]) | |||||
| if length == 0 { | |||||
| return nil | |||||
| } | |||||
| if length > c.maxIncomingPayload { | |||||
| // TODO(hanwen): should send Disconnect? | |||||
| return errors.New("ssh: incoming packet exceeds maximum payload size") | |||||
| } | |||||
| data := packet[headerLen:] | |||||
| if length != uint32(len(data)) { | |||||
| return errors.New("ssh: wrong packet length") | |||||
| } | |||||
| c.windowMu.Lock() | |||||
| if c.myWindow < length { | |||||
| c.windowMu.Unlock() | |||||
| // TODO(hanwen): should send Disconnect with reason? | |||||
| return errors.New("ssh: remote side wrote too much") | |||||
| } | |||||
| c.myWindow -= length | |||||
| c.windowMu.Unlock() | |||||
| if extended == 1 { | |||||
| c.extPending.write(data) | |||||
| } else if extended > 0 { | |||||
| // discard other extended data. | |||||
| } else { | |||||
| c.pending.write(data) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (c *channel) adjustWindow(n uint32) error { | |||||
| c.windowMu.Lock() | |||||
| // Since myWindow is managed on our side, and can never exceed | |||||
| // the initial window setting, we don't worry about overflow. | |||||
| c.myWindow += uint32(n) | |||||
| c.windowMu.Unlock() | |||||
| return c.sendMessage(windowAdjustMsg{ | |||||
| AdditionalBytes: uint32(n), | |||||
| }) | |||||
| } | |||||
| func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) { | |||||
| switch extended { | |||||
| case 1: | |||||
| n, err = c.extPending.Read(data) | |||||
| case 0: | |||||
| n, err = c.pending.Read(data) | |||||
| default: | |||||
| return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended) | |||||
| } | |||||
| if n > 0 { | |||||
| err = c.adjustWindow(uint32(n)) | |||||
| // sendWindowAdjust can return io.EOF if the remote | |||||
| // peer has closed the connection, however we want to | |||||
| // defer forwarding io.EOF to the caller of Read until | |||||
| // the buffer has been drained. | |||||
| if n > 0 && err == io.EOF { | |||||
| err = nil | |||||
| } | |||||
| } | |||||
| return n, err | |||||
| } | |||||
| func (c *channel) close() { | |||||
| c.pending.eof() | |||||
| c.extPending.eof() | |||||
| close(c.msg) | |||||
| close(c.incomingRequests) | |||||
| c.writeMu.Lock() | |||||
| // This is not necesary for a normal channel teardown, but if | |||||
| // there was another error, it is. | |||||
| c.sentClose = true | |||||
| c.writeMu.Unlock() | |||||
| // Unblock writers. | |||||
| c.remoteWin.close() | |||||
| } | |||||
| // responseMessageReceived is called when a success or failure message is | |||||
| // received on a channel to check that such a message is reasonable for the | |||||
| // given channel. | |||||
| func (c *channel) responseMessageReceived() error { | |||||
| if c.direction == channelInbound { | |||||
| return errors.New("ssh: channel response message received on inbound channel") | |||||
| } | |||||
| if c.decided { | |||||
| return errors.New("ssh: duplicate response received for channel") | |||||
| } | |||||
| c.decided = true | |||||
| return nil | |||||
| } | |||||
| func (c *channel) handlePacket(packet []byte) error { | |||||
| switch packet[0] { | |||||
| case msgChannelData, msgChannelExtendedData: | |||||
| return c.handleData(packet) | |||||
| case msgChannelClose: | |||||
| c.sendMessage(channelCloseMsg{PeersId: c.remoteId}) | |||||
| c.mux.chanList.remove(c.localId) | |||||
| c.close() | |||||
| return nil | |||||
| case msgChannelEOF: | |||||
| // RFC 4254 is mute on how EOF affects dataExt messages but | |||||
| // it is logical to signal EOF at the same time. | |||||
| c.extPending.eof() | |||||
| c.pending.eof() | |||||
| return nil | |||||
| } | |||||
| decoded, err := decode(packet) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| switch msg := decoded.(type) { | |||||
| case *channelOpenFailureMsg: | |||||
| if err := c.responseMessageReceived(); err != nil { | |||||
| return err | |||||
| } | |||||
| c.mux.chanList.remove(msg.PeersId) | |||||
| c.msg <- msg | |||||
| case *channelOpenConfirmMsg: | |||||
| if err := c.responseMessageReceived(); err != nil { | |||||
| return err | |||||
| } | |||||
| if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | |||||
| return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize) | |||||
| } | |||||
| c.remoteId = msg.MyId | |||||
| c.maxRemotePayload = msg.MaxPacketSize | |||||
| c.remoteWin.add(msg.MyWindow) | |||||
| c.msg <- msg | |||||
| case *windowAdjustMsg: | |||||
| if !c.remoteWin.add(msg.AdditionalBytes) { | |||||
| return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes) | |||||
| } | |||||
| case *channelRequestMsg: | |||||
| req := Request{ | |||||
| Type: msg.Request, | |||||
| WantReply: msg.WantReply, | |||||
| Payload: msg.RequestSpecificData, | |||||
| ch: c, | |||||
| } | |||||
| c.incomingRequests <- &req | |||||
| default: | |||||
| c.msg <- msg | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel { | |||||
| ch := &channel{ | |||||
| remoteWin: window{Cond: newCond()}, | |||||
| myWindow: channelWindowSize, | |||||
| pending: newBuffer(), | |||||
| extPending: newBuffer(), | |||||
| direction: direction, | |||||
| incomingRequests: make(chan *Request, 16), | |||||
| msg: make(chan interface{}, 16), | |||||
| chanType: chanType, | |||||
| extraData: extraData, | |||||
| mux: m, | |||||
| packetPool: make(map[uint32][]byte), | |||||
| } | |||||
| ch.localId = m.chanList.add(ch) | |||||
| return ch | |||||
| } | |||||
| var errUndecided = errors.New("ssh: must Accept or Reject channel") | |||||
| var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once") | |||||
| type extChannel struct { | |||||
| code uint32 | |||||
| ch *channel | |||||
| } | |||||
| func (e *extChannel) Write(data []byte) (n int, err error) { | |||||
| return e.ch.WriteExtended(data, e.code) | |||||
| } | |||||
| func (e *extChannel) Read(data []byte) (n int, err error) { | |||||
| return e.ch.ReadExtended(data, e.code) | |||||
| } | |||||
| func (c *channel) Accept() (Channel, <-chan *Request, error) { | |||||
| if c.decided { | |||||
| return nil, nil, errDecidedAlready | |||||
| } | |||||
| c.maxIncomingPayload = channelMaxPacket | |||||
| confirm := channelOpenConfirmMsg{ | |||||
| PeersId: c.remoteId, | |||||
| MyId: c.localId, | |||||
| MyWindow: c.myWindow, | |||||
| MaxPacketSize: c.maxIncomingPayload, | |||||
| } | |||||
| c.decided = true | |||||
| if err := c.sendMessage(confirm); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| return c, c.incomingRequests, nil | |||||
| } | |||||
| func (ch *channel) Reject(reason RejectionReason, message string) error { | |||||
| if ch.decided { | |||||
| return errDecidedAlready | |||||
| } | |||||
| reject := channelOpenFailureMsg{ | |||||
| PeersId: ch.remoteId, | |||||
| Reason: reason, | |||||
| Message: message, | |||||
| Language: "en", | |||||
| } | |||||
| ch.decided = true | |||||
| return ch.sendMessage(reject) | |||||
| } | |||||
| func (ch *channel) Read(data []byte) (int, error) { | |||||
| if !ch.decided { | |||||
| return 0, errUndecided | |||||
| } | |||||
| return ch.ReadExtended(data, 0) | |||||
| } | |||||
| func (ch *channel) Write(data []byte) (int, error) { | |||||
| if !ch.decided { | |||||
| return 0, errUndecided | |||||
| } | |||||
| return ch.WriteExtended(data, 0) | |||||
| } | |||||
| func (ch *channel) CloseWrite() error { | |||||
| if !ch.decided { | |||||
| return errUndecided | |||||
| } | |||||
| ch.sentEOF = true | |||||
| return ch.sendMessage(channelEOFMsg{ | |||||
| PeersId: ch.remoteId}) | |||||
| } | |||||
| func (ch *channel) Close() error { | |||||
| if !ch.decided { | |||||
| return errUndecided | |||||
| } | |||||
| return ch.sendMessage(channelCloseMsg{ | |||||
| PeersId: ch.remoteId}) | |||||
| } | |||||
| // Extended returns an io.ReadWriter that sends and receives data on the given, | |||||
| // SSH extended stream. Such streams are used, for example, for stderr. | |||||
| func (ch *channel) Extended(code uint32) io.ReadWriter { | |||||
| if !ch.decided { | |||||
| return nil | |||||
| } | |||||
| return &extChannel{code, ch} | |||||
| } | |||||
| func (ch *channel) Stderr() io.ReadWriter { | |||||
| return ch.Extended(1) | |||||
| } | |||||
| func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { | |||||
| if !ch.decided { | |||||
| return false, errUndecided | |||||
| } | |||||
| if wantReply { | |||||
| ch.sentRequestMu.Lock() | |||||
| defer ch.sentRequestMu.Unlock() | |||||
| } | |||||
| msg := channelRequestMsg{ | |||||
| PeersId: ch.remoteId, | |||||
| Request: name, | |||||
| WantReply: wantReply, | |||||
| RequestSpecificData: payload, | |||||
| } | |||||
| if err := ch.sendMessage(msg); err != nil { | |||||
| return false, err | |||||
| } | |||||
| if wantReply { | |||||
| m, ok := (<-ch.msg) | |||||
| if !ok { | |||||
| return false, io.EOF | |||||
| } | |||||
| switch m.(type) { | |||||
| case *channelRequestFailureMsg: | |||||
| return false, nil | |||||
| case *channelRequestSuccessMsg: | |||||
| return true, nil | |||||
| default: | |||||
| return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m) | |||||
| } | |||||
| } | |||||
| return false, nil | |||||
| } | |||||
| // ackRequest either sends an ack or nack to the channel request. | |||||
| func (ch *channel) ackRequest(ok bool) error { | |||||
| if !ch.decided { | |||||
| return errUndecided | |||||
| } | |||||
| var msg interface{} | |||||
| if !ok { | |||||
| msg = channelRequestFailureMsg{ | |||||
| PeersId: ch.remoteId, | |||||
| } | |||||
| } else { | |||||
| msg = channelRequestSuccessMsg{ | |||||
| PeersId: ch.remoteId, | |||||
| } | |||||
| } | |||||
| return ch.sendMessage(msg) | |||||
| } | |||||
| func (ch *channel) ChannelType() string { | |||||
| return ch.chanType | |||||
| } | |||||
| func (ch *channel) ExtraData() []byte { | |||||
| return ch.extraData | |||||
| } | |||||
| @@ -1,549 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "crypto/aes" | |||||
| "crypto/cipher" | |||||
| "crypto/rc4" | |||||
| "crypto/subtle" | |||||
| "encoding/binary" | |||||
| "errors" | |||||
| "fmt" | |||||
| "hash" | |||||
| "io" | |||||
| "io/ioutil" | |||||
| ) | |||||
| const ( | |||||
| packetSizeMultiple = 16 // TODO(huin) this should be determined by the cipher. | |||||
| // RFC 4253 section 6.1 defines a minimum packet size of 32768 that implementations | |||||
| // MUST be able to process (plus a few more kilobytes for padding and mac). The RFC | |||||
| // indicates implementations SHOULD be able to handle larger packet sizes, but then | |||||
| // waffles on about reasonable limits. | |||||
| // | |||||
| // OpenSSH caps their maxPacket at 256kB so we choose to do | |||||
| // the same. maxPacket is also used to ensure that uint32 | |||||
| // length fields do not overflow, so it should remain well | |||||
| // below 4G. | |||||
| maxPacket = 256 * 1024 | |||||
| ) | |||||
| // noneCipher implements cipher.Stream and provides no encryption. It is used | |||||
| // by the transport before the first key-exchange. | |||||
| type noneCipher struct{} | |||||
| func (c noneCipher) XORKeyStream(dst, src []byte) { | |||||
| copy(dst, src) | |||||
| } | |||||
| func newAESCTR(key, iv []byte) (cipher.Stream, error) { | |||||
| c, err := aes.NewCipher(key) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return cipher.NewCTR(c, iv), nil | |||||
| } | |||||
| func newRC4(key, iv []byte) (cipher.Stream, error) { | |||||
| return rc4.NewCipher(key) | |||||
| } | |||||
| type streamCipherMode struct { | |||||
| keySize int | |||||
| ivSize int | |||||
| skip int | |||||
| createFunc func(key, iv []byte) (cipher.Stream, error) | |||||
| } | |||||
| func (c *streamCipherMode) createStream(key, iv []byte) (cipher.Stream, error) { | |||||
| if len(key) < c.keySize { | |||||
| panic("ssh: key length too small for cipher") | |||||
| } | |||||
| if len(iv) < c.ivSize { | |||||
| panic("ssh: iv too small for cipher") | |||||
| } | |||||
| stream, err := c.createFunc(key[:c.keySize], iv[:c.ivSize]) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var streamDump []byte | |||||
| if c.skip > 0 { | |||||
| streamDump = make([]byte, 512) | |||||
| } | |||||
| for remainingToDump := c.skip; remainingToDump > 0; { | |||||
| dumpThisTime := remainingToDump | |||||
| if dumpThisTime > len(streamDump) { | |||||
| dumpThisTime = len(streamDump) | |||||
| } | |||||
| stream.XORKeyStream(streamDump[:dumpThisTime], streamDump[:dumpThisTime]) | |||||
| remainingToDump -= dumpThisTime | |||||
| } | |||||
| return stream, nil | |||||
| } | |||||
| // cipherModes documents properties of supported ciphers. Ciphers not included | |||||
| // are not supported and will not be negotiated, even if explicitly requested in | |||||
| // ClientConfig.Crypto.Ciphers. | |||||
| var cipherModes = map[string]*streamCipherMode{ | |||||
| // Ciphers from RFC4344, which introduced many CTR-based ciphers. Algorithms | |||||
| // are defined in the order specified in the RFC. | |||||
| "aes128-ctr": {16, aes.BlockSize, 0, newAESCTR}, | |||||
| "aes192-ctr": {24, aes.BlockSize, 0, newAESCTR}, | |||||
| "aes256-ctr": {32, aes.BlockSize, 0, newAESCTR}, | |||||
| // Ciphers from RFC4345, which introduces security-improved arcfour ciphers. | |||||
| // They are defined in the order specified in the RFC. | |||||
| "arcfour128": {16, 0, 1536, newRC4}, | |||||
| "arcfour256": {32, 0, 1536, newRC4}, | |||||
| // Cipher defined in RFC 4253, which describes SSH Transport Layer Protocol. | |||||
| // Note that this cipher is not safe, as stated in RFC 4253: "Arcfour (and | |||||
| // RC4) has problems with weak keys, and should be used with caution." | |||||
| // RFC4345 introduces improved versions of Arcfour. | |||||
| "arcfour": {16, 0, 0, newRC4}, | |||||
| // AES-GCM is not a stream cipher, so it is constructed with a | |||||
| // special case. If we add any more non-stream ciphers, we | |||||
| // should invest a cleaner way to do this. | |||||
| gcmCipherID: {16, 12, 0, nil}, | |||||
| // insecure cipher, see http://www.isg.rhul.ac.uk/~kp/SandPfinal.pdf | |||||
| // uncomment below to enable it. | |||||
| // aes128cbcID: {16, aes.BlockSize, 0, nil}, | |||||
| } | |||||
| // prefixLen is the length of the packet prefix that contains the packet length | |||||
| // and number of padding bytes. | |||||
| const prefixLen = 5 | |||||
| // streamPacketCipher is a packetCipher using a stream cipher. | |||||
| type streamPacketCipher struct { | |||||
| mac hash.Hash | |||||
| cipher cipher.Stream | |||||
| // The following members are to avoid per-packet allocations. | |||||
| prefix [prefixLen]byte | |||||
| seqNumBytes [4]byte | |||||
| padding [2 * packetSizeMultiple]byte | |||||
| packetData []byte | |||||
| macResult []byte | |||||
| } | |||||
| // readPacket reads and decrypt a single packet from the reader argument. | |||||
| func (s *streamPacketCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { | |||||
| if _, err := io.ReadFull(r, s.prefix[:]); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) | |||||
| length := binary.BigEndian.Uint32(s.prefix[0:4]) | |||||
| paddingLength := uint32(s.prefix[4]) | |||||
| var macSize uint32 | |||||
| if s.mac != nil { | |||||
| s.mac.Reset() | |||||
| binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) | |||||
| s.mac.Write(s.seqNumBytes[:]) | |||||
| s.mac.Write(s.prefix[:]) | |||||
| macSize = uint32(s.mac.Size()) | |||||
| } | |||||
| if length <= paddingLength+1 { | |||||
| return nil, errors.New("ssh: invalid packet length, packet too small") | |||||
| } | |||||
| if length > maxPacket { | |||||
| return nil, errors.New("ssh: invalid packet length, packet too large") | |||||
| } | |||||
| // the maxPacket check above ensures that length-1+macSize | |||||
| // does not overflow. | |||||
| if uint32(cap(s.packetData)) < length-1+macSize { | |||||
| s.packetData = make([]byte, length-1+macSize) | |||||
| } else { | |||||
| s.packetData = s.packetData[:length-1+macSize] | |||||
| } | |||||
| if _, err := io.ReadFull(r, s.packetData); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| mac := s.packetData[length-1:] | |||||
| data := s.packetData[:length-1] | |||||
| s.cipher.XORKeyStream(data, data) | |||||
| if s.mac != nil { | |||||
| s.mac.Write(data) | |||||
| s.macResult = s.mac.Sum(s.macResult[:0]) | |||||
| if subtle.ConstantTimeCompare(s.macResult, mac) != 1 { | |||||
| return nil, errors.New("ssh: MAC failure") | |||||
| } | |||||
| } | |||||
| return s.packetData[:length-paddingLength-1], nil | |||||
| } | |||||
| // writePacket encrypts and sends a packet of data to the writer argument | |||||
| func (s *streamPacketCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { | |||||
| if len(packet) > maxPacket { | |||||
| return errors.New("ssh: packet too large") | |||||
| } | |||||
| paddingLength := packetSizeMultiple - (prefixLen+len(packet))%packetSizeMultiple | |||||
| if paddingLength < 4 { | |||||
| paddingLength += packetSizeMultiple | |||||
| } | |||||
| length := len(packet) + 1 + paddingLength | |||||
| binary.BigEndian.PutUint32(s.prefix[:], uint32(length)) | |||||
| s.prefix[4] = byte(paddingLength) | |||||
| padding := s.padding[:paddingLength] | |||||
| if _, err := io.ReadFull(rand, padding); err != nil { | |||||
| return err | |||||
| } | |||||
| if s.mac != nil { | |||||
| s.mac.Reset() | |||||
| binary.BigEndian.PutUint32(s.seqNumBytes[:], seqNum) | |||||
| s.mac.Write(s.seqNumBytes[:]) | |||||
| s.mac.Write(s.prefix[:]) | |||||
| s.mac.Write(packet) | |||||
| s.mac.Write(padding) | |||||
| } | |||||
| s.cipher.XORKeyStream(s.prefix[:], s.prefix[:]) | |||||
| s.cipher.XORKeyStream(packet, packet) | |||||
| s.cipher.XORKeyStream(padding, padding) | |||||
| if _, err := w.Write(s.prefix[:]); err != nil { | |||||
| return err | |||||
| } | |||||
| if _, err := w.Write(packet); err != nil { | |||||
| return err | |||||
| } | |||||
| if _, err := w.Write(padding); err != nil { | |||||
| return err | |||||
| } | |||||
| if s.mac != nil { | |||||
| s.macResult = s.mac.Sum(s.macResult[:0]) | |||||
| if _, err := w.Write(s.macResult); err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| return nil | |||||
| } | |||||
| type gcmCipher struct { | |||||
| aead cipher.AEAD | |||||
| prefix [4]byte | |||||
| iv []byte | |||||
| buf []byte | |||||
| } | |||||
| func newGCMCipher(iv, key, macKey []byte) (packetCipher, error) { | |||||
| c, err := aes.NewCipher(key) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| aead, err := cipher.NewGCM(c) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &gcmCipher{ | |||||
| aead: aead, | |||||
| iv: iv, | |||||
| }, nil | |||||
| } | |||||
| const gcmTagSize = 16 | |||||
| func (c *gcmCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { | |||||
| // Pad out to multiple of 16 bytes. This is different from the | |||||
| // stream cipher because that encrypts the length too. | |||||
| padding := byte(packetSizeMultiple - (1+len(packet))%packetSizeMultiple) | |||||
| if padding < 4 { | |||||
| padding += packetSizeMultiple | |||||
| } | |||||
| length := uint32(len(packet) + int(padding) + 1) | |||||
| binary.BigEndian.PutUint32(c.prefix[:], length) | |||||
| if _, err := w.Write(c.prefix[:]); err != nil { | |||||
| return err | |||||
| } | |||||
| if cap(c.buf) < int(length) { | |||||
| c.buf = make([]byte, length) | |||||
| } else { | |||||
| c.buf = c.buf[:length] | |||||
| } | |||||
| c.buf[0] = padding | |||||
| copy(c.buf[1:], packet) | |||||
| if _, err := io.ReadFull(rand, c.buf[1+len(packet):]); err != nil { | |||||
| return err | |||||
| } | |||||
| c.buf = c.aead.Seal(c.buf[:0], c.iv, c.buf, c.prefix[:]) | |||||
| if _, err := w.Write(c.buf); err != nil { | |||||
| return err | |||||
| } | |||||
| c.incIV() | |||||
| return nil | |||||
| } | |||||
| func (c *gcmCipher) incIV() { | |||||
| for i := 4 + 7; i >= 4; i-- { | |||||
| c.iv[i]++ | |||||
| if c.iv[i] != 0 { | |||||
| break | |||||
| } | |||||
| } | |||||
| } | |||||
| func (c *gcmCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { | |||||
| if _, err := io.ReadFull(r, c.prefix[:]); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| length := binary.BigEndian.Uint32(c.prefix[:]) | |||||
| if length > maxPacket { | |||||
| return nil, errors.New("ssh: max packet length exceeded.") | |||||
| } | |||||
| if cap(c.buf) < int(length+gcmTagSize) { | |||||
| c.buf = make([]byte, length+gcmTagSize) | |||||
| } else { | |||||
| c.buf = c.buf[:length+gcmTagSize] | |||||
| } | |||||
| if _, err := io.ReadFull(r, c.buf); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| plain, err := c.aead.Open(c.buf[:0], c.iv, c.buf, c.prefix[:]) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c.incIV() | |||||
| padding := plain[0] | |||||
| if padding < 4 || padding >= 20 { | |||||
| return nil, fmt.Errorf("ssh: illegal padding %d", padding) | |||||
| } | |||||
| if int(padding+1) >= len(plain) { | |||||
| return nil, fmt.Errorf("ssh: padding %d too large", padding) | |||||
| } | |||||
| plain = plain[1 : length-uint32(padding)] | |||||
| return plain, nil | |||||
| } | |||||
| // cbcCipher implements aes128-cbc cipher defined in RFC 4253 section 6.1 | |||||
| type cbcCipher struct { | |||||
| mac hash.Hash | |||||
| macSize uint32 | |||||
| decrypter cipher.BlockMode | |||||
| encrypter cipher.BlockMode | |||||
| // The following members are to avoid per-packet allocations. | |||||
| seqNumBytes [4]byte | |||||
| packetData []byte | |||||
| macResult []byte | |||||
| // Amount of data we should still read to hide which | |||||
| // verification error triggered. | |||||
| oracleCamouflage uint32 | |||||
| } | |||||
| func newAESCBCCipher(iv, key, macKey []byte, algs directionAlgorithms) (packetCipher, error) { | |||||
| c, err := aes.NewCipher(key) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| cbc := &cbcCipher{ | |||||
| mac: macModes[algs.MAC].new(macKey), | |||||
| decrypter: cipher.NewCBCDecrypter(c, iv), | |||||
| encrypter: cipher.NewCBCEncrypter(c, iv), | |||||
| packetData: make([]byte, 1024), | |||||
| } | |||||
| if cbc.mac != nil { | |||||
| cbc.macSize = uint32(cbc.mac.Size()) | |||||
| } | |||||
| return cbc, nil | |||||
| } | |||||
| func maxUInt32(a, b int) uint32 { | |||||
| if a > b { | |||||
| return uint32(a) | |||||
| } | |||||
| return uint32(b) | |||||
| } | |||||
| const ( | |||||
| cbcMinPacketSizeMultiple = 8 | |||||
| cbcMinPacketSize = 16 | |||||
| cbcMinPaddingSize = 4 | |||||
| ) | |||||
| // cbcError represents a verification error that may leak information. | |||||
| type cbcError string | |||||
| func (e cbcError) Error() string { return string(e) } | |||||
| func (c *cbcCipher) readPacket(seqNum uint32, r io.Reader) ([]byte, error) { | |||||
| p, err := c.readPacketLeaky(seqNum, r) | |||||
| if err != nil { | |||||
| if _, ok := err.(cbcError); ok { | |||||
| // Verification error: read a fixed amount of | |||||
| // data, to make distinguishing between | |||||
| // failing MAC and failing length check more | |||||
| // difficult. | |||||
| io.CopyN(ioutil.Discard, r, int64(c.oracleCamouflage)) | |||||
| } | |||||
| } | |||||
| return p, err | |||||
| } | |||||
| func (c *cbcCipher) readPacketLeaky(seqNum uint32, r io.Reader) ([]byte, error) { | |||||
| blockSize := c.decrypter.BlockSize() | |||||
| // Read the header, which will include some of the subsequent data in the | |||||
| // case of block ciphers - this is copied back to the payload later. | |||||
| // How many bytes of payload/padding will be read with this first read. | |||||
| firstBlockLength := uint32((prefixLen + blockSize - 1) / blockSize * blockSize) | |||||
| firstBlock := c.packetData[:firstBlockLength] | |||||
| if _, err := io.ReadFull(r, firstBlock); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c.oracleCamouflage = maxPacket + 4 + c.macSize - firstBlockLength | |||||
| c.decrypter.CryptBlocks(firstBlock, firstBlock) | |||||
| length := binary.BigEndian.Uint32(firstBlock[:4]) | |||||
| if length > maxPacket { | |||||
| return nil, cbcError("ssh: packet too large") | |||||
| } | |||||
| if length+4 < maxUInt32(cbcMinPacketSize, blockSize) { | |||||
| // The minimum size of a packet is 16 (or the cipher block size, whichever | |||||
| // is larger) bytes. | |||||
| return nil, cbcError("ssh: packet too small") | |||||
| } | |||||
| // The length of the packet (including the length field but not the MAC) must | |||||
| // be a multiple of the block size or 8, whichever is larger. | |||||
| if (length+4)%maxUInt32(cbcMinPacketSizeMultiple, blockSize) != 0 { | |||||
| return nil, cbcError("ssh: invalid packet length multiple") | |||||
| } | |||||
| paddingLength := uint32(firstBlock[4]) | |||||
| if paddingLength < cbcMinPaddingSize || length <= paddingLength+1 { | |||||
| return nil, cbcError("ssh: invalid packet length") | |||||
| } | |||||
| // Positions within the c.packetData buffer: | |||||
| macStart := 4 + length | |||||
| paddingStart := macStart - paddingLength | |||||
| // Entire packet size, starting before length, ending at end of mac. | |||||
| entirePacketSize := macStart + c.macSize | |||||
| // Ensure c.packetData is large enough for the entire packet data. | |||||
| if uint32(cap(c.packetData)) < entirePacketSize { | |||||
| // Still need to upsize and copy, but this should be rare at runtime, only | |||||
| // on upsizing the packetData buffer. | |||||
| c.packetData = make([]byte, entirePacketSize) | |||||
| copy(c.packetData, firstBlock) | |||||
| } else { | |||||
| c.packetData = c.packetData[:entirePacketSize] | |||||
| } | |||||
| if n, err := io.ReadFull(r, c.packetData[firstBlockLength:]); err != nil { | |||||
| return nil, err | |||||
| } else { | |||||
| c.oracleCamouflage -= uint32(n) | |||||
| } | |||||
| remainingCrypted := c.packetData[firstBlockLength:macStart] | |||||
| c.decrypter.CryptBlocks(remainingCrypted, remainingCrypted) | |||||
| mac := c.packetData[macStart:] | |||||
| if c.mac != nil { | |||||
| c.mac.Reset() | |||||
| binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) | |||||
| c.mac.Write(c.seqNumBytes[:]) | |||||
| c.mac.Write(c.packetData[:macStart]) | |||||
| c.macResult = c.mac.Sum(c.macResult[:0]) | |||||
| if subtle.ConstantTimeCompare(c.macResult, mac) != 1 { | |||||
| return nil, cbcError("ssh: MAC failure") | |||||
| } | |||||
| } | |||||
| return c.packetData[prefixLen:paddingStart], nil | |||||
| } | |||||
| func (c *cbcCipher) writePacket(seqNum uint32, w io.Writer, rand io.Reader, packet []byte) error { | |||||
| effectiveBlockSize := maxUInt32(cbcMinPacketSizeMultiple, c.encrypter.BlockSize()) | |||||
| // Length of encrypted portion of the packet (header, payload, padding). | |||||
| // Enforce minimum padding and packet size. | |||||
| encLength := maxUInt32(prefixLen+len(packet)+cbcMinPaddingSize, cbcMinPaddingSize) | |||||
| // Enforce block size. | |||||
| encLength = (encLength + effectiveBlockSize - 1) / effectiveBlockSize * effectiveBlockSize | |||||
| length := encLength - 4 | |||||
| paddingLength := int(length) - (1 + len(packet)) | |||||
| // Overall buffer contains: header, payload, padding, mac. | |||||
| // Space for the MAC is reserved in the capacity but not the slice length. | |||||
| bufferSize := encLength + c.macSize | |||||
| if uint32(cap(c.packetData)) < bufferSize { | |||||
| c.packetData = make([]byte, encLength, bufferSize) | |||||
| } else { | |||||
| c.packetData = c.packetData[:encLength] | |||||
| } | |||||
| p := c.packetData | |||||
| // Packet header. | |||||
| binary.BigEndian.PutUint32(p, length) | |||||
| p = p[4:] | |||||
| p[0] = byte(paddingLength) | |||||
| // Payload. | |||||
| p = p[1:] | |||||
| copy(p, packet) | |||||
| // Padding. | |||||
| p = p[len(packet):] | |||||
| if _, err := io.ReadFull(rand, p); err != nil { | |||||
| return err | |||||
| } | |||||
| if c.mac != nil { | |||||
| c.mac.Reset() | |||||
| binary.BigEndian.PutUint32(c.seqNumBytes[:], seqNum) | |||||
| c.mac.Write(c.seqNumBytes[:]) | |||||
| c.mac.Write(c.packetData) | |||||
| // The MAC is now appended into the capacity reserved for it earlier. | |||||
| c.packetData = c.mac.Sum(c.packetData) | |||||
| } | |||||
| c.encrypter.CryptBlocks(c.packetData[:encLength], c.packetData[:encLength]) | |||||
| if _, err := w.Write(c.packetData); err != nil { | |||||
| return err | |||||
| } | |||||
| return nil | |||||
| } | |||||
| @@ -1,127 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto" | |||||
| "crypto/aes" | |||||
| "crypto/rand" | |||||
| "testing" | |||||
| ) | |||||
| func TestDefaultCiphersExist(t *testing.T) { | |||||
| for _, cipherAlgo := range supportedCiphers { | |||||
| if _, ok := cipherModes[cipherAlgo]; !ok { | |||||
| t.Errorf("default cipher %q is unknown", cipherAlgo) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestPacketCiphers(t *testing.T) { | |||||
| // Still test aes128cbc cipher althought it's commented out. | |||||
| cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} | |||||
| defer delete(cipherModes, aes128cbcID) | |||||
| for cipher := range cipherModes { | |||||
| kr := &kexResult{Hash: crypto.SHA1} | |||||
| algs := directionAlgorithms{ | |||||
| Cipher: cipher, | |||||
| MAC: "hmac-sha1", | |||||
| Compression: "none", | |||||
| } | |||||
| client, err := newPacketCipher(clientKeys, algs, kr) | |||||
| if err != nil { | |||||
| t.Errorf("newPacketCipher(client, %q): %v", cipher, err) | |||||
| continue | |||||
| } | |||||
| server, err := newPacketCipher(clientKeys, algs, kr) | |||||
| if err != nil { | |||||
| t.Errorf("newPacketCipher(client, %q): %v", cipher, err) | |||||
| continue | |||||
| } | |||||
| want := "bla bla" | |||||
| input := []byte(want) | |||||
| buf := &bytes.Buffer{} | |||||
| if err := client.writePacket(0, buf, rand.Reader, input); err != nil { | |||||
| t.Errorf("writePacket(%q): %v", cipher, err) | |||||
| continue | |||||
| } | |||||
| packet, err := server.readPacket(0, buf) | |||||
| if err != nil { | |||||
| t.Errorf("readPacket(%q): %v", cipher, err) | |||||
| continue | |||||
| } | |||||
| if string(packet) != want { | |||||
| t.Errorf("roundtrip(%q): got %q, want %q", cipher, packet, want) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestCBCOracleCounterMeasure(t *testing.T) { | |||||
| cipherModes[aes128cbcID] = &streamCipherMode{16, aes.BlockSize, 0, nil} | |||||
| defer delete(cipherModes, aes128cbcID) | |||||
| kr := &kexResult{Hash: crypto.SHA1} | |||||
| algs := directionAlgorithms{ | |||||
| Cipher: aes128cbcID, | |||||
| MAC: "hmac-sha1", | |||||
| Compression: "none", | |||||
| } | |||||
| client, err := newPacketCipher(clientKeys, algs, kr) | |||||
| if err != nil { | |||||
| t.Fatalf("newPacketCipher(client): %v", err) | |||||
| } | |||||
| want := "bla bla" | |||||
| input := []byte(want) | |||||
| buf := &bytes.Buffer{} | |||||
| if err := client.writePacket(0, buf, rand.Reader, input); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| packetSize := buf.Len() | |||||
| buf.Write(make([]byte, 2*maxPacket)) | |||||
| // We corrupt each byte, but this usually will only test the | |||||
| // 'packet too large' or 'MAC failure' cases. | |||||
| lastRead := -1 | |||||
| for i := 0; i < packetSize; i++ { | |||||
| server, err := newPacketCipher(clientKeys, algs, kr) | |||||
| if err != nil { | |||||
| t.Fatalf("newPacketCipher(client): %v", err) | |||||
| } | |||||
| fresh := &bytes.Buffer{} | |||||
| fresh.Write(buf.Bytes()) | |||||
| fresh.Bytes()[i] ^= 0x01 | |||||
| before := fresh.Len() | |||||
| _, err = server.readPacket(0, fresh) | |||||
| if err == nil { | |||||
| t.Errorf("corrupt byte %d: readPacket succeeded ", i) | |||||
| continue | |||||
| } | |||||
| if _, ok := err.(cbcError); !ok { | |||||
| t.Errorf("corrupt byte %d: got %v (%T), want cbcError", i, err, err) | |||||
| continue | |||||
| } | |||||
| after := fresh.Len() | |||||
| bytesRead := before - after | |||||
| if bytesRead < maxPacket { | |||||
| t.Errorf("corrupt byte %d: read %d bytes, want more than %d", i, bytesRead, maxPacket) | |||||
| continue | |||||
| } | |||||
| if i > 0 && bytesRead != lastRead { | |||||
| t.Errorf("corrupt byte %d: read %d bytes, want %d bytes read", i, bytesRead, lastRead) | |||||
| } | |||||
| lastRead = bytesRead | |||||
| } | |||||
| } | |||||
| @@ -1,213 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "errors" | |||||
| "fmt" | |||||
| "net" | |||||
| "sync" | |||||
| ) | |||||
| // Client implements a traditional SSH client that supports shells, | |||||
| // subprocesses, port forwarding and tunneled dialing. | |||||
| type Client struct { | |||||
| Conn | |||||
| forwards forwardList // forwarded tcpip connections from the remote side | |||||
| mu sync.Mutex | |||||
| channelHandlers map[string]chan NewChannel | |||||
| } | |||||
| // HandleChannelOpen returns a channel on which NewChannel requests | |||||
| // for the given type are sent. If the type already is being handled, | |||||
| // nil is returned. The channel is closed when the connection is closed. | |||||
| func (c *Client) HandleChannelOpen(channelType string) <-chan NewChannel { | |||||
| c.mu.Lock() | |||||
| defer c.mu.Unlock() | |||||
| if c.channelHandlers == nil { | |||||
| // The SSH channel has been closed. | |||||
| c := make(chan NewChannel) | |||||
| close(c) | |||||
| return c | |||||
| } | |||||
| ch := c.channelHandlers[channelType] | |||||
| if ch != nil { | |||||
| return nil | |||||
| } | |||||
| ch = make(chan NewChannel, 16) | |||||
| c.channelHandlers[channelType] = ch | |||||
| return ch | |||||
| } | |||||
| // NewClient creates a Client on top of the given connection. | |||||
| func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { | |||||
| conn := &Client{ | |||||
| Conn: c, | |||||
| channelHandlers: make(map[string]chan NewChannel, 1), | |||||
| } | |||||
| go conn.handleGlobalRequests(reqs) | |||||
| go conn.handleChannelOpens(chans) | |||||
| go func() { | |||||
| conn.Wait() | |||||
| conn.forwards.closeAll() | |||||
| }() | |||||
| go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) | |||||
| return conn | |||||
| } | |||||
| // NewClientConn establishes an authenticated SSH connection using c | |||||
| // as the underlying transport. The Request and NewChannel channels | |||||
| // must be serviced or the connection will hang. | |||||
| func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { | |||||
| fullConf := *config | |||||
| fullConf.SetDefaults() | |||||
| conn := &connection{ | |||||
| sshConn: sshConn{conn: c}, | |||||
| } | |||||
| if err := conn.clientHandshake(addr, &fullConf); err != nil { | |||||
| c.Close() | |||||
| return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) | |||||
| } | |||||
| conn.mux = newMux(conn.transport) | |||||
| return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil | |||||
| } | |||||
| // clientHandshake performs the client side key exchange. See RFC 4253 Section | |||||
| // 7. | |||||
| func (c *connection) clientHandshake(dialAddress string, config *ClientConfig) error { | |||||
| if config.ClientVersion != "" { | |||||
| c.clientVersion = []byte(config.ClientVersion) | |||||
| } else { | |||||
| c.clientVersion = []byte(packageVersion) | |||||
| } | |||||
| var err error | |||||
| c.serverVersion, err = exchangeVersions(c.sshConn.conn, c.clientVersion) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| c.transport = newClientTransport( | |||||
| newTransport(c.sshConn.conn, config.Rand, true /* is client */), | |||||
| c.clientVersion, c.serverVersion, config, dialAddress, c.sshConn.RemoteAddr()) | |||||
| if err := c.transport.requestKeyChange(); err != nil { | |||||
| return err | |||||
| } | |||||
| if packet, err := c.transport.readPacket(); err != nil { | |||||
| return err | |||||
| } else if packet[0] != msgNewKeys { | |||||
| return unexpectedMessageError(msgNewKeys, packet[0]) | |||||
| } | |||||
| // We just did the key change, so the session ID is established. | |||||
| c.sessionID = c.transport.getSessionID() | |||||
| return c.clientAuthenticate(config) | |||||
| } | |||||
| // verifyHostKeySignature verifies the host key obtained in the key | |||||
| // exchange. | |||||
| func verifyHostKeySignature(hostKey PublicKey, result *kexResult) error { | |||||
| sig, rest, ok := parseSignatureBody(result.Signature) | |||||
| if len(rest) > 0 || !ok { | |||||
| return errors.New("ssh: signature parse error") | |||||
| } | |||||
| return hostKey.Verify(result.H, sig) | |||||
| } | |||||
| // NewSession opens a new Session for this client. (A session is a remote | |||||
| // execution of a program.) | |||||
| func (c *Client) NewSession() (*Session, error) { | |||||
| ch, in, err := c.OpenChannel("session", nil) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return newSession(ch, in) | |||||
| } | |||||
| func (c *Client) handleGlobalRequests(incoming <-chan *Request) { | |||||
| for r := range incoming { | |||||
| // This handles keepalive messages and matches | |||||
| // the behaviour of OpenSSH. | |||||
| r.Reply(false, nil) | |||||
| } | |||||
| } | |||||
| // handleChannelOpens channel open messages from the remote side. | |||||
| func (c *Client) handleChannelOpens(in <-chan NewChannel) { | |||||
| for ch := range in { | |||||
| c.mu.Lock() | |||||
| handler := c.channelHandlers[ch.ChannelType()] | |||||
| c.mu.Unlock() | |||||
| if handler != nil { | |||||
| handler <- ch | |||||
| } else { | |||||
| ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type: %v", ch.ChannelType())) | |||||
| } | |||||
| } | |||||
| c.mu.Lock() | |||||
| for _, ch := range c.channelHandlers { | |||||
| close(ch) | |||||
| } | |||||
| c.channelHandlers = nil | |||||
| c.mu.Unlock() | |||||
| } | |||||
| // Dial starts a client connection to the given SSH server. It is a | |||||
| // convenience function that connects to the given network address, | |||||
| // initiates the SSH handshake, and then sets up a Client. For access | |||||
| // to incoming channels and requests, use net.Dial with NewClientConn | |||||
| // instead. | |||||
| func Dial(network, addr string, config *ClientConfig) (*Client, error) { | |||||
| conn, err := net.Dial(network, addr) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| c, chans, reqs, err := NewClientConn(conn, addr, config) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return NewClient(c, chans, reqs), nil | |||||
| } | |||||
| // A ClientConfig structure is used to configure a Client. It must not be | |||||
| // modified after having been passed to an SSH function. | |||||
| type ClientConfig struct { | |||||
| // Config contains configuration that is shared between clients and | |||||
| // servers. | |||||
| Config | |||||
| // User contains the username to authenticate as. | |||||
| User string | |||||
| // Auth contains possible authentication methods to use with the | |||||
| // server. Only the first instance of a particular RFC 4252 method will | |||||
| // be used during authentication. | |||||
| Auth []AuthMethod | |||||
| // HostKeyCallback, if not nil, is called during the cryptographic | |||||
| // handshake to validate the server's host key. A nil HostKeyCallback | |||||
| // implies that all host keys are accepted. | |||||
| HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |||||
| // ClientVersion contains the version identification string that will | |||||
| // be used for the connection. If empty, a reasonable default is used. | |||||
| ClientVersion string | |||||
| // HostKeyAlgorithms lists the key types that the client will | |||||
| // accept from the server as host key, in order of | |||||
| // preference. If empty, a reasonable default is used. Any | |||||
| // string returned from PublicKey.Type method may be used, or | |||||
| // any of the CertAlgoXxxx and KeyAlgoXxxx constants. | |||||
| HostKeyAlgorithms []string | |||||
| } | |||||
| @@ -1,441 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| ) | |||||
| // clientAuthenticate authenticates with the remote server. See RFC 4252. | |||||
| func (c *connection) clientAuthenticate(config *ClientConfig) error { | |||||
| // initiate user auth session | |||||
| if err := c.transport.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})); err != nil { | |||||
| return err | |||||
| } | |||||
| packet, err := c.transport.readPacket() | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| var serviceAccept serviceAcceptMsg | |||||
| if err := Unmarshal(packet, &serviceAccept); err != nil { | |||||
| return err | |||||
| } | |||||
| // during the authentication phase the client first attempts the "none" method | |||||
| // then any untried methods suggested by the server. | |||||
| tried := make(map[string]bool) | |||||
| var lastMethods []string | |||||
| for auth := AuthMethod(new(noneAuth)); auth != nil; { | |||||
| ok, methods, err := auth.auth(c.transport.getSessionID(), config.User, c.transport, config.Rand) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if ok { | |||||
| // success | |||||
| return nil | |||||
| } | |||||
| tried[auth.method()] = true | |||||
| if methods == nil { | |||||
| methods = lastMethods | |||||
| } | |||||
| lastMethods = methods | |||||
| auth = nil | |||||
| findNext: | |||||
| for _, a := range config.Auth { | |||||
| candidateMethod := a.method() | |||||
| if tried[candidateMethod] { | |||||
| continue | |||||
| } | |||||
| for _, meth := range methods { | |||||
| if meth == candidateMethod { | |||||
| auth = a | |||||
| break findNext | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", keys(tried)) | |||||
| } | |||||
| func keys(m map[string]bool) []string { | |||||
| s := make([]string, 0, len(m)) | |||||
| for key := range m { | |||||
| s = append(s, key) | |||||
| } | |||||
| return s | |||||
| } | |||||
| // An AuthMethod represents an instance of an RFC 4252 authentication method. | |||||
| type AuthMethod interface { | |||||
| // auth authenticates user over transport t. | |||||
| // Returns true if authentication is successful. | |||||
| // If authentication is not successful, a []string of alternative | |||||
| // method names is returned. If the slice is nil, it will be ignored | |||||
| // and the previous set of possible methods will be reused. | |||||
| auth(session []byte, user string, p packetConn, rand io.Reader) (bool, []string, error) | |||||
| // method returns the RFC 4252 method name. | |||||
| method() string | |||||
| } | |||||
| // "none" authentication, RFC 4252 section 5.2. | |||||
| type noneAuth int | |||||
| func (n *noneAuth) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |||||
| if err := c.writePacket(Marshal(&userAuthRequestMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: "none", | |||||
| })); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| return handleAuthResponse(c) | |||||
| } | |||||
| func (n *noneAuth) method() string { | |||||
| return "none" | |||||
| } | |||||
| // passwordCallback is an AuthMethod that fetches the password through | |||||
| // a function call, e.g. by prompting the user. | |||||
| type passwordCallback func() (password string, err error) | |||||
| func (cb passwordCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |||||
| type passwordAuthMsg struct { | |||||
| User string `sshtype:"50"` | |||||
| Service string | |||||
| Method string | |||||
| Reply bool | |||||
| Password string | |||||
| } | |||||
| pw, err := cb() | |||||
| // REVIEW NOTE: is there a need to support skipping a password attempt? | |||||
| // The program may only find out that the user doesn't have a password | |||||
| // when prompting. | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| if err := c.writePacket(Marshal(&passwordAuthMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: cb.method(), | |||||
| Reply: false, | |||||
| Password: pw, | |||||
| })); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| return handleAuthResponse(c) | |||||
| } | |||||
| func (cb passwordCallback) method() string { | |||||
| return "password" | |||||
| } | |||||
| // Password returns an AuthMethod using the given password. | |||||
| func Password(secret string) AuthMethod { | |||||
| return passwordCallback(func() (string, error) { return secret, nil }) | |||||
| } | |||||
| // PasswordCallback returns an AuthMethod that uses a callback for | |||||
| // fetching a password. | |||||
| func PasswordCallback(prompt func() (secret string, err error)) AuthMethod { | |||||
| return passwordCallback(prompt) | |||||
| } | |||||
| type publickeyAuthMsg struct { | |||||
| User string `sshtype:"50"` | |||||
| Service string | |||||
| Method string | |||||
| // HasSig indicates to the receiver packet that the auth request is signed and | |||||
| // should be used for authentication of the request. | |||||
| HasSig bool | |||||
| Algoname string | |||||
| PubKey []byte | |||||
| // Sig is tagged with "rest" so Marshal will exclude it during | |||||
| // validateKey | |||||
| Sig []byte `ssh:"rest"` | |||||
| } | |||||
| // publicKeyCallback is an AuthMethod that uses a set of key | |||||
| // pairs for authentication. | |||||
| type publicKeyCallback func() ([]Signer, error) | |||||
| func (cb publicKeyCallback) method() string { | |||||
| return "publickey" | |||||
| } | |||||
| func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |||||
| // Authentication is performed in two stages. The first stage sends an | |||||
| // enquiry to test if each key is acceptable to the remote. The second | |||||
| // stage attempts to authenticate with the valid keys obtained in the | |||||
| // first stage. | |||||
| signers, err := cb() | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| var validKeys []Signer | |||||
| for _, signer := range signers { | |||||
| if ok, err := validateKey(signer.PublicKey(), user, c); ok { | |||||
| validKeys = append(validKeys, signer) | |||||
| } else { | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| } | |||||
| } | |||||
| // methods that may continue if this auth is not successful. | |||||
| var methods []string | |||||
| for _, signer := range validKeys { | |||||
| pub := signer.PublicKey() | |||||
| pubKey := pub.Marshal() | |||||
| sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: cb.method(), | |||||
| }, []byte(pub.Type()), pubKey)) | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| // manually wrap the serialized signature in a string | |||||
| s := Marshal(sign) | |||||
| sig := make([]byte, stringLength(len(s))) | |||||
| marshalString(sig, s) | |||||
| msg := publickeyAuthMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: cb.method(), | |||||
| HasSig: true, | |||||
| Algoname: pub.Type(), | |||||
| PubKey: pubKey, | |||||
| Sig: sig, | |||||
| } | |||||
| p := Marshal(&msg) | |||||
| if err := c.writePacket(p); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| var success bool | |||||
| success, methods, err = handleAuthResponse(c) | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| if success { | |||||
| return success, methods, err | |||||
| } | |||||
| } | |||||
| return false, methods, nil | |||||
| } | |||||
| // validateKey validates the key provided is acceptable to the server. | |||||
| func validateKey(key PublicKey, user string, c packetConn) (bool, error) { | |||||
| pubKey := key.Marshal() | |||||
| msg := publickeyAuthMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: "publickey", | |||||
| HasSig: false, | |||||
| Algoname: key.Type(), | |||||
| PubKey: pubKey, | |||||
| } | |||||
| if err := c.writePacket(Marshal(&msg)); err != nil { | |||||
| return false, err | |||||
| } | |||||
| return confirmKeyAck(key, c) | |||||
| } | |||||
| func confirmKeyAck(key PublicKey, c packetConn) (bool, error) { | |||||
| pubKey := key.Marshal() | |||||
| algoname := key.Type() | |||||
| for { | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return false, err | |||||
| } | |||||
| switch packet[0] { | |||||
| case msgUserAuthBanner: | |||||
| // TODO(gpaul): add callback to present the banner to the user | |||||
| case msgUserAuthPubKeyOk: | |||||
| var msg userAuthPubKeyOkMsg | |||||
| if err := Unmarshal(packet, &msg); err != nil { | |||||
| return false, err | |||||
| } | |||||
| if msg.Algo != algoname || !bytes.Equal(msg.PubKey, pubKey) { | |||||
| return false, nil | |||||
| } | |||||
| return true, nil | |||||
| case msgUserAuthFailure: | |||||
| return false, nil | |||||
| default: | |||||
| return false, unexpectedMessageError(msgUserAuthSuccess, packet[0]) | |||||
| } | |||||
| } | |||||
| } | |||||
| // PublicKeys returns an AuthMethod that uses the given key | |||||
| // pairs. | |||||
| func PublicKeys(signers ...Signer) AuthMethod { | |||||
| return publicKeyCallback(func() ([]Signer, error) { return signers, nil }) | |||||
| } | |||||
| // PublicKeysCallback returns an AuthMethod that runs the given | |||||
| // function to obtain a list of key pairs. | |||||
| func PublicKeysCallback(getSigners func() (signers []Signer, err error)) AuthMethod { | |||||
| return publicKeyCallback(getSigners) | |||||
| } | |||||
| // handleAuthResponse returns whether the preceding authentication request succeeded | |||||
| // along with a list of remaining authentication methods to try next and | |||||
| // an error if an unexpected response was received. | |||||
| func handleAuthResponse(c packetConn) (bool, []string, error) { | |||||
| for { | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| switch packet[0] { | |||||
| case msgUserAuthBanner: | |||||
| // TODO: add callback to present the banner to the user | |||||
| case msgUserAuthFailure: | |||||
| var msg userAuthFailureMsg | |||||
| if err := Unmarshal(packet, &msg); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| return false, msg.Methods, nil | |||||
| case msgUserAuthSuccess: | |||||
| return true, nil, nil | |||||
| case msgDisconnect: | |||||
| return false, nil, io.EOF | |||||
| default: | |||||
| return false, nil, unexpectedMessageError(msgUserAuthSuccess, packet[0]) | |||||
| } | |||||
| } | |||||
| } | |||||
| // KeyboardInteractiveChallenge should print questions, optionally | |||||
| // disabling echoing (e.g. for passwords), and return all the answers. | |||||
| // Challenge may be called multiple times in a single session. After | |||||
| // successful authentication, the server may send a challenge with no | |||||
| // questions, for which the user and instruction messages should be | |||||
| // printed. RFC 4256 section 3.3 details how the UI should behave for | |||||
| // both CLI and GUI environments. | |||||
| type KeyboardInteractiveChallenge func(user, instruction string, questions []string, echos []bool) (answers []string, err error) | |||||
| // KeyboardInteractive returns a AuthMethod using a prompt/response | |||||
| // sequence controlled by the server. | |||||
| func KeyboardInteractive(challenge KeyboardInteractiveChallenge) AuthMethod { | |||||
| return challenge | |||||
| } | |||||
| func (cb KeyboardInteractiveChallenge) method() string { | |||||
| return "keyboard-interactive" | |||||
| } | |||||
| func (cb KeyboardInteractiveChallenge) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |||||
| type initiateMsg struct { | |||||
| User string `sshtype:"50"` | |||||
| Service string | |||||
| Method string | |||||
| Language string | |||||
| Submethods string | |||||
| } | |||||
| if err := c.writePacket(Marshal(&initiateMsg{ | |||||
| User: user, | |||||
| Service: serviceSSH, | |||||
| Method: "keyboard-interactive", | |||||
| })); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| for { | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| // like handleAuthResponse, but with less options. | |||||
| switch packet[0] { | |||||
| case msgUserAuthBanner: | |||||
| // TODO: Print banners during userauth. | |||||
| continue | |||||
| case msgUserAuthInfoRequest: | |||||
| // OK | |||||
| case msgUserAuthFailure: | |||||
| var msg userAuthFailureMsg | |||||
| if err := Unmarshal(packet, &msg); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| return false, msg.Methods, nil | |||||
| case msgUserAuthSuccess: | |||||
| return true, nil, nil | |||||
| default: | |||||
| return false, nil, unexpectedMessageError(msgUserAuthInfoRequest, packet[0]) | |||||
| } | |||||
| var msg userAuthInfoRequestMsg | |||||
| if err := Unmarshal(packet, &msg); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| // Manually unpack the prompt/echo pairs. | |||||
| rest := msg.Prompts | |||||
| var prompts []string | |||||
| var echos []bool | |||||
| for i := 0; i < int(msg.NumPrompts); i++ { | |||||
| prompt, r, ok := parseString(rest) | |||||
| if !ok || len(r) == 0 { | |||||
| return false, nil, errors.New("ssh: prompt format error") | |||||
| } | |||||
| prompts = append(prompts, string(prompt)) | |||||
| echos = append(echos, r[0] != 0) | |||||
| rest = r[1:] | |||||
| } | |||||
| if len(rest) != 0 { | |||||
| return false, nil, errors.New("ssh: extra data following keyboard-interactive pairs") | |||||
| } | |||||
| answers, err := cb(msg.User, msg.Instruction, prompts, echos) | |||||
| if err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| if len(answers) != len(prompts) { | |||||
| return false, nil, errors.New("ssh: not enough answers from keyboard-interactive callback") | |||||
| } | |||||
| responseLength := 1 + 4 | |||||
| for _, a := range answers { | |||||
| responseLength += stringLength(len(a)) | |||||
| } | |||||
| serialized := make([]byte, responseLength) | |||||
| p := serialized | |||||
| p[0] = msgUserAuthInfoResponse | |||||
| p = p[1:] | |||||
| p = marshalUint32(p, uint32(len(answers))) | |||||
| for _, a := range answers { | |||||
| p = marshalString(p, []byte(a)) | |||||
| } | |||||
| if err := c.writePacket(serialized); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,393 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "errors" | |||||
| "fmt" | |||||
| "strings" | |||||
| "testing" | |||||
| ) | |||||
| type keyboardInteractive map[string]string | |||||
| func (cr keyboardInteractive) Challenge(user string, instruction string, questions []string, echos []bool) ([]string, error) { | |||||
| var answers []string | |||||
| for _, q := range questions { | |||||
| answers = append(answers, cr[q]) | |||||
| } | |||||
| return answers, nil | |||||
| } | |||||
| // reused internally by tests | |||||
| var clientPassword = "tiger" | |||||
| // tryAuth runs a handshake with a given config against an SSH server | |||||
| // with config serverConfig | |||||
| func tryAuth(t *testing.T, config *ClientConfig) error { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| certChecker := CertChecker{ | |||||
| IsAuthority: func(k PublicKey) bool { | |||||
| return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) | |||||
| }, | |||||
| UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { | |||||
| if conn.User() == "testuser" && bytes.Equal(key.Marshal(), testPublicKeys["rsa"].Marshal()) { | |||||
| return nil, nil | |||||
| } | |||||
| return nil, fmt.Errorf("pubkey for %q not acceptable", conn.User()) | |||||
| }, | |||||
| IsRevoked: func(c *Certificate) bool { | |||||
| return c.Serial == 666 | |||||
| }, | |||||
| } | |||||
| serverConfig := &ServerConfig{ | |||||
| PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { | |||||
| if conn.User() == "testuser" && string(pass) == clientPassword { | |||||
| return nil, nil | |||||
| } | |||||
| return nil, errors.New("password auth failed") | |||||
| }, | |||||
| PublicKeyCallback: certChecker.Authenticate, | |||||
| KeyboardInteractiveCallback: func(conn ConnMetadata, challenge KeyboardInteractiveChallenge) (*Permissions, error) { | |||||
| ans, err := challenge("user", | |||||
| "instruction", | |||||
| []string{"question1", "question2"}, | |||||
| []bool{true, true}) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| ok := conn.User() == "testuser" && ans[0] == "answer1" && ans[1] == "answer2" | |||||
| if ok { | |||||
| challenge("user", "motd", nil, nil) | |||||
| return nil, nil | |||||
| } | |||||
| return nil, errors.New("keyboard-interactive failed") | |||||
| }, | |||||
| AuthLogCallback: func(conn ConnMetadata, method string, err error) { | |||||
| t.Logf("user %q, method %q: %v", conn.User(), method, err) | |||||
| }, | |||||
| } | |||||
| serverConfig.AddHostKey(testSigners["rsa"]) | |||||
| go newServer(c1, serverConfig) | |||||
| _, _, _, err = NewClientConn(c2, "", config) | |||||
| return err | |||||
| } | |||||
| func TestClientAuthPublicKey(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["rsa"]), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("unable to dial remote side: %s", err) | |||||
| } | |||||
| } | |||||
| func TestAuthMethodPassword(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| Password(clientPassword), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("unable to dial remote side: %s", err) | |||||
| } | |||||
| } | |||||
| func TestAuthMethodFallback(t *testing.T) { | |||||
| var passwordCalled bool | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["rsa"]), | |||||
| PasswordCallback( | |||||
| func() (string, error) { | |||||
| passwordCalled = true | |||||
| return "WRONG", nil | |||||
| }), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("unable to dial remote side: %s", err) | |||||
| } | |||||
| if passwordCalled { | |||||
| t.Errorf("password auth tried before public-key auth.") | |||||
| } | |||||
| } | |||||
| func TestAuthMethodWrongPassword(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| Password("wrong"), | |||||
| PublicKeys(testSigners["rsa"]), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("unable to dial remote side: %s", err) | |||||
| } | |||||
| } | |||||
| func TestAuthMethodKeyboardInteractive(t *testing.T) { | |||||
| answers := keyboardInteractive(map[string]string{ | |||||
| "question1": "answer1", | |||||
| "question2": "answer2", | |||||
| }) | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| KeyboardInteractive(answers.Challenge), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("unable to dial remote side: %s", err) | |||||
| } | |||||
| } | |||||
| func TestAuthMethodWrongKeyboardInteractive(t *testing.T) { | |||||
| answers := keyboardInteractive(map[string]string{ | |||||
| "question1": "answer1", | |||||
| "question2": "WRONG", | |||||
| }) | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| KeyboardInteractive(answers.Challenge), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err == nil { | |||||
| t.Fatalf("wrong answers should not have authenticated with KeyboardInteractive") | |||||
| } | |||||
| } | |||||
| // the mock server will only authenticate ssh-rsa keys | |||||
| func TestAuthMethodInvalidPublicKey(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["dsa"]), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err == nil { | |||||
| t.Fatalf("dsa private key should not have authenticated with rsa public key") | |||||
| } | |||||
| } | |||||
| // the client should authenticate with the second key | |||||
| func TestAuthMethodRSAandDSA(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["dsa"], testSigners["rsa"]), | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("client could not authenticate with rsa key: %v", err) | |||||
| } | |||||
| } | |||||
| func TestClientHMAC(t *testing.T) { | |||||
| for _, mac := range supportedMACs { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["rsa"]), | |||||
| }, | |||||
| Config: Config{ | |||||
| MACs: []string{mac}, | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err != nil { | |||||
| t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) | |||||
| } | |||||
| } | |||||
| } | |||||
| // issue 4285. | |||||
| func TestClientUnsupportedCipher(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(), | |||||
| }, | |||||
| Config: Config{ | |||||
| Ciphers: []string{"aes128-cbc"}, // not currently supported | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err == nil { | |||||
| t.Errorf("expected no ciphers in common") | |||||
| } | |||||
| } | |||||
| func TestClientUnsupportedKex(t *testing.T) { | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(), | |||||
| }, | |||||
| Config: Config{ | |||||
| KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported | |||||
| }, | |||||
| } | |||||
| if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { | |||||
| t.Errorf("got %v, expected 'common algorithm'", err) | |||||
| } | |||||
| } | |||||
| func TestClientLoginCert(t *testing.T) { | |||||
| cert := &Certificate{ | |||||
| Key: testPublicKeys["rsa"], | |||||
| ValidBefore: CertTimeInfinity, | |||||
| CertType: UserCert, | |||||
| } | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| certSigner, err := NewCertSigner(cert, testSigners["rsa"]) | |||||
| if err != nil { | |||||
| t.Fatalf("NewCertSigner: %v", err) | |||||
| } | |||||
| clientConfig := &ClientConfig{ | |||||
| User: "user", | |||||
| } | |||||
| clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) | |||||
| t.Log("should succeed") | |||||
| if err := tryAuth(t, clientConfig); err != nil { | |||||
| t.Errorf("cert login failed: %v", err) | |||||
| } | |||||
| t.Log("corrupted signature") | |||||
| cert.Signature.Blob[0]++ | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login passed with corrupted sig") | |||||
| } | |||||
| t.Log("revoked") | |||||
| cert.Serial = 666 | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("revoked cert login succeeded") | |||||
| } | |||||
| cert.Serial = 1 | |||||
| t.Log("sign with wrong key") | |||||
| cert.SignCert(rand.Reader, testSigners["dsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login passed with non-authoritive key") | |||||
| } | |||||
| t.Log("host cert") | |||||
| cert.CertType = HostCert | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login passed with wrong type") | |||||
| } | |||||
| cert.CertType = UserCert | |||||
| t.Log("principal specified") | |||||
| cert.ValidPrincipals = []string{"user"} | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err != nil { | |||||
| t.Errorf("cert login failed: %v", err) | |||||
| } | |||||
| t.Log("wrong principal specified") | |||||
| cert.ValidPrincipals = []string{"fred"} | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login passed with wrong principal") | |||||
| } | |||||
| cert.ValidPrincipals = nil | |||||
| t.Log("added critical option") | |||||
| cert.CriticalOptions = map[string]string{"root-access": "yes"} | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login passed with unrecognized critical option") | |||||
| } | |||||
| t.Log("allowed source address") | |||||
| cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42/24"} | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err != nil { | |||||
| t.Errorf("cert login with source-address failed: %v", err) | |||||
| } | |||||
| t.Log("disallowed source address") | |||||
| cert.CriticalOptions = map[string]string{"source-address": "127.0.0.42"} | |||||
| cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |||||
| if err := tryAuth(t, clientConfig); err == nil { | |||||
| t.Errorf("cert login with source-address succeeded") | |||||
| } | |||||
| } | |||||
| func testPermissionsPassing(withPermissions bool, t *testing.T) { | |||||
| serverConfig := &ServerConfig{ | |||||
| PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { | |||||
| if conn.User() == "nopermissions" { | |||||
| return nil, nil | |||||
| } else { | |||||
| return &Permissions{}, nil | |||||
| } | |||||
| }, | |||||
| } | |||||
| serverConfig.AddHostKey(testSigners["rsa"]) | |||||
| clientConfig := &ClientConfig{ | |||||
| Auth: []AuthMethod{ | |||||
| PublicKeys(testSigners["rsa"]), | |||||
| }, | |||||
| } | |||||
| if withPermissions { | |||||
| clientConfig.User = "permissions" | |||||
| } else { | |||||
| clientConfig.User = "nopermissions" | |||||
| } | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| go NewClientConn(c2, "", clientConfig) | |||||
| serverConn, err := newServer(c1, serverConfig) | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| if p := serverConn.Permissions; (p != nil) != withPermissions { | |||||
| t.Fatalf("withPermissions is %t, but Permissions object is %#v", withPermissions, p) | |||||
| } | |||||
| } | |||||
| func TestPermissionsPassing(t *testing.T) { | |||||
| testPermissionsPassing(true, t) | |||||
| } | |||||
| func TestNoPermissionsPassing(t *testing.T) { | |||||
| testPermissionsPassing(false, t) | |||||
| } | |||||
| @@ -1,39 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "net" | |||||
| "testing" | |||||
| ) | |||||
| func testClientVersion(t *testing.T, config *ClientConfig, expected string) { | |||||
| clientConn, serverConn := net.Pipe() | |||||
| defer clientConn.Close() | |||||
| receivedVersion := make(chan string, 1) | |||||
| go func() { | |||||
| version, err := readVersion(serverConn) | |||||
| if err != nil { | |||||
| receivedVersion <- "" | |||||
| } else { | |||||
| receivedVersion <- string(version) | |||||
| } | |||||
| serverConn.Close() | |||||
| }() | |||||
| NewClientConn(clientConn, "", config) | |||||
| actual := <-receivedVersion | |||||
| if actual != expected { | |||||
| t.Fatalf("got %s; want %s", actual, expected) | |||||
| } | |||||
| } | |||||
| func TestCustomClientVersion(t *testing.T) { | |||||
| version := "Test-Client-Version-0.0" | |||||
| testClientVersion(t, &ClientConfig{ClientVersion: version}, version) | |||||
| } | |||||
| func TestDefaultClientVersion(t *testing.T) { | |||||
| testClientVersion(t, &ClientConfig{}, packageVersion) | |||||
| } | |||||
| @@ -1,354 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "crypto" | |||||
| "crypto/rand" | |||||
| "fmt" | |||||
| "io" | |||||
| "sync" | |||||
| _ "crypto/sha1" | |||||
| _ "crypto/sha256" | |||||
| _ "crypto/sha512" | |||||
| ) | |||||
| // These are string constants in the SSH protocol. | |||||
| const ( | |||||
| compressionNone = "none" | |||||
| serviceUserAuth = "ssh-userauth" | |||||
| serviceSSH = "ssh-connection" | |||||
| ) | |||||
| // supportedCiphers specifies the supported ciphers in preference order. | |||||
| var supportedCiphers = []string{ | |||||
| "aes128-ctr", "aes192-ctr", "aes256-ctr", | |||||
| "aes128-gcm@openssh.com", | |||||
| "arcfour256", "arcfour128", | |||||
| } | |||||
| // supportedKexAlgos specifies the supported key-exchange algorithms in | |||||
| // preference order. | |||||
| var supportedKexAlgos = []string{ | |||||
| kexAlgoCurve25519SHA256, | |||||
| // P384 and P521 are not constant-time yet, but since we don't | |||||
| // reuse ephemeral keys, using them for ECDH should be OK. | |||||
| kexAlgoECDH256, kexAlgoECDH384, kexAlgoECDH521, | |||||
| kexAlgoDH14SHA1, kexAlgoDH1SHA1, | |||||
| } | |||||
| // supportedKexAlgos specifies the supported host-key algorithms (i.e. methods | |||||
| // of authenticating servers) in preference order. | |||||
| var supportedHostKeyAlgos = []string{ | |||||
| CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, | |||||
| CertAlgoECDSA384v01, CertAlgoECDSA521v01, | |||||
| KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, | |||||
| KeyAlgoRSA, KeyAlgoDSA, | |||||
| } | |||||
| // supportedMACs specifies a default set of MAC algorithms in preference order. | |||||
| // This is based on RFC 4253, section 6.4, but with hmac-md5 variants removed | |||||
| // because they have reached the end of their useful life. | |||||
| var supportedMACs = []string{ | |||||
| "hmac-sha2-256", "hmac-sha1", "hmac-sha1-96", | |||||
| } | |||||
| var supportedCompressions = []string{compressionNone} | |||||
| // hashFuncs keeps the mapping of supported algorithms to their respective | |||||
| // hashes needed for signature verification. | |||||
| var hashFuncs = map[string]crypto.Hash{ | |||||
| KeyAlgoRSA: crypto.SHA1, | |||||
| KeyAlgoDSA: crypto.SHA1, | |||||
| KeyAlgoECDSA256: crypto.SHA256, | |||||
| KeyAlgoECDSA384: crypto.SHA384, | |||||
| KeyAlgoECDSA521: crypto.SHA512, | |||||
| CertAlgoRSAv01: crypto.SHA1, | |||||
| CertAlgoDSAv01: crypto.SHA1, | |||||
| CertAlgoECDSA256v01: crypto.SHA256, | |||||
| CertAlgoECDSA384v01: crypto.SHA384, | |||||
| CertAlgoECDSA521v01: crypto.SHA512, | |||||
| } | |||||
| // unexpectedMessageError results when the SSH message that we received didn't | |||||
| // match what we wanted. | |||||
| func unexpectedMessageError(expected, got uint8) error { | |||||
| return fmt.Errorf("ssh: unexpected message type %d (expected %d)", got, expected) | |||||
| } | |||||
| // parseError results from a malformed SSH message. | |||||
| func parseError(tag uint8) error { | |||||
| return fmt.Errorf("ssh: parse error in message type %d", tag) | |||||
| } | |||||
| func findCommon(what string, client []string, server []string) (common string, err error) { | |||||
| for _, c := range client { | |||||
| for _, s := range server { | |||||
| if c == s { | |||||
| return c, nil | |||||
| } | |||||
| } | |||||
| } | |||||
| return "", fmt.Errorf("ssh: no common algorithm for %s; client offered: %v, server offered: %v", what, client, server) | |||||
| } | |||||
| type directionAlgorithms struct { | |||||
| Cipher string | |||||
| MAC string | |||||
| Compression string | |||||
| } | |||||
| type algorithms struct { | |||||
| kex string | |||||
| hostKey string | |||||
| w directionAlgorithms | |||||
| r directionAlgorithms | |||||
| } | |||||
| func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms, err error) { | |||||
| result := &algorithms{} | |||||
| result.kex, err = findCommon("key exchange", clientKexInit.KexAlgos, serverKexInit.KexAlgos) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.hostKey, err = findCommon("host key", clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.w.Cipher, err = findCommon("client to server cipher", clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.r.Cipher, err = findCommon("server to client cipher", clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.w.MAC, err = findCommon("client to server MAC", clientKexInit.MACsClientServer, serverKexInit.MACsClientServer) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.r.MAC, err = findCommon("server to client MAC", clientKexInit.MACsServerClient, serverKexInit.MACsServerClient) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.w.Compression, err = findCommon("client to server compression", clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| result.r.Compression, err = findCommon("server to client compression", clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| return result, nil | |||||
| } | |||||
| // If rekeythreshold is too small, we can't make any progress sending | |||||
| // stuff. | |||||
| const minRekeyThreshold uint64 = 256 | |||||
| // Config contains configuration data common to both ServerConfig and | |||||
| // ClientConfig. | |||||
| type Config struct { | |||||
| // Rand provides the source of entropy for cryptographic | |||||
| // primitives. If Rand is nil, the cryptographic random reader | |||||
| // in package crypto/rand will be used. | |||||
| Rand io.Reader | |||||
| // The maximum number of bytes sent or received after which a | |||||
| // new key is negotiated. It must be at least 256. If | |||||
| // unspecified, 1 gigabyte is used. | |||||
| RekeyThreshold uint64 | |||||
| // The allowed key exchanges algorithms. If unspecified then a | |||||
| // default set of algorithms is used. | |||||
| KeyExchanges []string | |||||
| // The allowed cipher algorithms. If unspecified then a sensible | |||||
| // default is used. | |||||
| Ciphers []string | |||||
| // The allowed MAC algorithms. If unspecified then a sensible default | |||||
| // is used. | |||||
| MACs []string | |||||
| } | |||||
| // SetDefaults sets sensible values for unset fields in config. This is | |||||
| // exported for testing: Configs passed to SSH functions are copied and have | |||||
| // default values set automatically. | |||||
| func (c *Config) SetDefaults() { | |||||
| if c.Rand == nil { | |||||
| c.Rand = rand.Reader | |||||
| } | |||||
| if c.Ciphers == nil { | |||||
| c.Ciphers = supportedCiphers | |||||
| } | |||||
| var ciphers []string | |||||
| for _, c := range c.Ciphers { | |||||
| if cipherModes[c] != nil { | |||||
| // reject the cipher if we have no cipherModes definition | |||||
| ciphers = append(ciphers, c) | |||||
| } | |||||
| } | |||||
| c.Ciphers = ciphers | |||||
| if c.KeyExchanges == nil { | |||||
| c.KeyExchanges = supportedKexAlgos | |||||
| } | |||||
| if c.MACs == nil { | |||||
| c.MACs = supportedMACs | |||||
| } | |||||
| if c.RekeyThreshold == 0 { | |||||
| // RFC 4253, section 9 suggests rekeying after 1G. | |||||
| c.RekeyThreshold = 1 << 30 | |||||
| } | |||||
| if c.RekeyThreshold < minRekeyThreshold { | |||||
| c.RekeyThreshold = minRekeyThreshold | |||||
| } | |||||
| } | |||||
| // buildDataSignedForAuth returns the data that is signed in order to prove | |||||
| // possession of a private key. See RFC 4252, section 7. | |||||
| func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte { | |||||
| data := struct { | |||||
| Session []byte | |||||
| Type byte | |||||
| User string | |||||
| Service string | |||||
| Method string | |||||
| Sign bool | |||||
| Algo []byte | |||||
| PubKey []byte | |||||
| }{ | |||||
| sessionId, | |||||
| msgUserAuthRequest, | |||||
| req.User, | |||||
| req.Service, | |||||
| req.Method, | |||||
| true, | |||||
| algo, | |||||
| pubKey, | |||||
| } | |||||
| return Marshal(data) | |||||
| } | |||||
| func appendU16(buf []byte, n uint16) []byte { | |||||
| return append(buf, byte(n>>8), byte(n)) | |||||
| } | |||||
| func appendU32(buf []byte, n uint32) []byte { | |||||
| return append(buf, byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) | |||||
| } | |||||
| func appendU64(buf []byte, n uint64) []byte { | |||||
| return append(buf, | |||||
| byte(n>>56), byte(n>>48), byte(n>>40), byte(n>>32), | |||||
| byte(n>>24), byte(n>>16), byte(n>>8), byte(n)) | |||||
| } | |||||
| func appendInt(buf []byte, n int) []byte { | |||||
| return appendU32(buf, uint32(n)) | |||||
| } | |||||
| func appendString(buf []byte, s string) []byte { | |||||
| buf = appendU32(buf, uint32(len(s))) | |||||
| buf = append(buf, s...) | |||||
| return buf | |||||
| } | |||||
| func appendBool(buf []byte, b bool) []byte { | |||||
| if b { | |||||
| return append(buf, 1) | |||||
| } | |||||
| return append(buf, 0) | |||||
| } | |||||
| // newCond is a helper to hide the fact that there is no usable zero | |||||
| // value for sync.Cond. | |||||
| func newCond() *sync.Cond { return sync.NewCond(new(sync.Mutex)) } | |||||
| // window represents the buffer available to clients | |||||
| // wishing to write to a channel. | |||||
| type window struct { | |||||
| *sync.Cond | |||||
| win uint32 // RFC 4254 5.2 says the window size can grow to 2^32-1 | |||||
| writeWaiters int | |||||
| closed bool | |||||
| } | |||||
| // add adds win to the amount of window available | |||||
| // for consumers. | |||||
| func (w *window) add(win uint32) bool { | |||||
| // a zero sized window adjust is a noop. | |||||
| if win == 0 { | |||||
| return true | |||||
| } | |||||
| w.L.Lock() | |||||
| if w.win+win < win { | |||||
| w.L.Unlock() | |||||
| return false | |||||
| } | |||||
| w.win += win | |||||
| // It is unusual that multiple goroutines would be attempting to reserve | |||||
| // window space, but not guaranteed. Use broadcast to notify all waiters | |||||
| // that additional window is available. | |||||
| w.Broadcast() | |||||
| w.L.Unlock() | |||||
| return true | |||||
| } | |||||
| // close sets the window to closed, so all reservations fail | |||||
| // immediately. | |||||
| func (w *window) close() { | |||||
| w.L.Lock() | |||||
| w.closed = true | |||||
| w.Broadcast() | |||||
| w.L.Unlock() | |||||
| } | |||||
| // reserve reserves win from the available window capacity. | |||||
| // If no capacity remains, reserve will block. reserve may | |||||
| // return less than requested. | |||||
| func (w *window) reserve(win uint32) (uint32, error) { | |||||
| var err error | |||||
| w.L.Lock() | |||||
| w.writeWaiters++ | |||||
| w.Broadcast() | |||||
| for w.win == 0 && !w.closed { | |||||
| w.Wait() | |||||
| } | |||||
| w.writeWaiters-- | |||||
| if w.win < win { | |||||
| win = w.win | |||||
| } | |||||
| w.win -= win | |||||
| if w.closed { | |||||
| err = io.EOF | |||||
| } | |||||
| w.L.Unlock() | |||||
| return win, err | |||||
| } | |||||
| // waitWriterBlocked waits until some goroutine is blocked for further | |||||
| // writes. It is used in tests only. | |||||
| func (w *window) waitWriterBlocked() { | |||||
| w.Cond.L.Lock() | |||||
| for w.writeWaiters == 0 { | |||||
| w.Cond.Wait() | |||||
| } | |||||
| w.Cond.L.Unlock() | |||||
| } | |||||
| @@ -1,144 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "fmt" | |||||
| "net" | |||||
| ) | |||||
| // OpenChannelError is returned if the other side rejects an | |||||
| // OpenChannel request. | |||||
| type OpenChannelError struct { | |||||
| Reason RejectionReason | |||||
| Message string | |||||
| } | |||||
| func (e *OpenChannelError) Error() string { | |||||
| return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) | |||||
| } | |||||
| // ConnMetadata holds metadata for the connection. | |||||
| type ConnMetadata interface { | |||||
| // User returns the user ID for this connection. | |||||
| // It is empty if no authentication is used. | |||||
| User() string | |||||
| // SessionID returns the sesson hash, also denoted by H. | |||||
| SessionID() []byte | |||||
| // ClientVersion returns the client's version string as hashed | |||||
| // into the session ID. | |||||
| ClientVersion() []byte | |||||
| // ServerVersion returns the server's version string as hashed | |||||
| // into the session ID. | |||||
| ServerVersion() []byte | |||||
| // RemoteAddr returns the remote address for this connection. | |||||
| RemoteAddr() net.Addr | |||||
| // LocalAddr returns the local address for this connection. | |||||
| LocalAddr() net.Addr | |||||
| } | |||||
| // Conn represents an SSH connection for both server and client roles. | |||||
| // Conn is the basis for implementing an application layer, such | |||||
| // as ClientConn, which implements the traditional shell access for | |||||
| // clients. | |||||
| type Conn interface { | |||||
| ConnMetadata | |||||
| // SendRequest sends a global request, and returns the | |||||
| // reply. If wantReply is true, it returns the response status | |||||
| // and payload. See also RFC4254, section 4. | |||||
| SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) | |||||
| // OpenChannel tries to open an channel. If the request is | |||||
| // rejected, it returns *OpenChannelError. On success it returns | |||||
| // the SSH Channel and a Go channel for incoming, out-of-band | |||||
| // requests. The Go channel must be serviced, or the | |||||
| // connection will hang. | |||||
| OpenChannel(name string, data []byte) (Channel, <-chan *Request, error) | |||||
| // Close closes the underlying network connection | |||||
| Close() error | |||||
| // Wait blocks until the connection has shut down, and returns the | |||||
| // error causing the shutdown. | |||||
| Wait() error | |||||
| // TODO(hanwen): consider exposing: | |||||
| // RequestKeyChange | |||||
| // Disconnect | |||||
| } | |||||
| // DiscardRequests consumes and rejects all requests from the | |||||
| // passed-in channel. | |||||
| func DiscardRequests(in <-chan *Request) { | |||||
| for req := range in { | |||||
| if req.WantReply { | |||||
| req.Reply(false, nil) | |||||
| } | |||||
| } | |||||
| } | |||||
| // A connection represents an incoming connection. | |||||
| type connection struct { | |||||
| transport *handshakeTransport | |||||
| sshConn | |||||
| // The connection protocol. | |||||
| *mux | |||||
| } | |||||
| func (c *connection) Close() error { | |||||
| return c.sshConn.conn.Close() | |||||
| } | |||||
| // sshconn provides net.Conn metadata, but disallows direct reads and | |||||
| // writes. | |||||
| type sshConn struct { | |||||
| conn net.Conn | |||||
| user string | |||||
| sessionID []byte | |||||
| clientVersion []byte | |||||
| serverVersion []byte | |||||
| } | |||||
| func dup(src []byte) []byte { | |||||
| dst := make([]byte, len(src)) | |||||
| copy(dst, src) | |||||
| return dst | |||||
| } | |||||
| func (c *sshConn) User() string { | |||||
| return c.user | |||||
| } | |||||
| func (c *sshConn) RemoteAddr() net.Addr { | |||||
| return c.conn.RemoteAddr() | |||||
| } | |||||
| func (c *sshConn) Close() error { | |||||
| return c.conn.Close() | |||||
| } | |||||
| func (c *sshConn) LocalAddr() net.Addr { | |||||
| return c.conn.LocalAddr() | |||||
| } | |||||
| func (c *sshConn) SessionID() []byte { | |||||
| return dup(c.sessionID) | |||||
| } | |||||
| func (c *sshConn) ClientVersion() []byte { | |||||
| return dup(c.clientVersion) | |||||
| } | |||||
| func (c *sshConn) ServerVersion() []byte { | |||||
| return dup(c.serverVersion) | |||||
| } | |||||
| @@ -1,18 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| /* | |||||
| Package ssh implements an SSH client and server. | |||||
| SSH is a transport security protocol, an authentication protocol and a | |||||
| family of application protocols. The most typical application level | |||||
| protocol is a remote shell and this is specifically implemented. However, | |||||
| the multiplexed nature of SSH is exposed to users that wish to support | |||||
| others. | |||||
| References: | |||||
| [PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD | |||||
| [SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 | |||||
| */ | |||||
| package ssh | |||||
| @@ -1,211 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh_test | |||||
| import ( | |||||
| "bytes" | |||||
| "fmt" | |||||
| "io/ioutil" | |||||
| "log" | |||||
| "net" | |||||
| "net/http" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh/terminal" | |||||
| ) | |||||
| func ExampleNewServerConn() { | |||||
| // An SSH server is represented by a ServerConfig, which holds | |||||
| // certificate details and handles authentication of ServerConns. | |||||
| config := &ssh.ServerConfig{ | |||||
| PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { | |||||
| // Should use constant-time compare (or better, salt+hash) in | |||||
| // a production setting. | |||||
| if c.User() == "testuser" && string(pass) == "tiger" { | |||||
| return nil, nil | |||||
| } | |||||
| return nil, fmt.Errorf("password rejected for %q", c.User()) | |||||
| }, | |||||
| } | |||||
| privateBytes, err := ioutil.ReadFile("id_rsa") | |||||
| if err != nil { | |||||
| panic("Failed to load private key") | |||||
| } | |||||
| private, err := ssh.ParsePrivateKey(privateBytes) | |||||
| if err != nil { | |||||
| panic("Failed to parse private key") | |||||
| } | |||||
| config.AddHostKey(private) | |||||
| // Once a ServerConfig has been configured, connections can be | |||||
| // accepted. | |||||
| listener, err := net.Listen("tcp", "0.0.0.0:2022") | |||||
| if err != nil { | |||||
| panic("failed to listen for connection") | |||||
| } | |||||
| nConn, err := listener.Accept() | |||||
| if err != nil { | |||||
| panic("failed to accept incoming connection") | |||||
| } | |||||
| // Before use, a handshake must be performed on the incoming | |||||
| // net.Conn. | |||||
| _, chans, reqs, err := ssh.NewServerConn(nConn, config) | |||||
| if err != nil { | |||||
| panic("failed to handshake") | |||||
| } | |||||
| // The incoming Request channel must be serviced. | |||||
| go ssh.DiscardRequests(reqs) | |||||
| // Service the incoming Channel channel. | |||||
| for newChannel := range chans { | |||||
| // Channels have a type, depending on the application level | |||||
| // protocol intended. In the case of a shell, the type is | |||||
| // "session" and ServerShell may be used to present a simple | |||||
| // terminal interface. | |||||
| if newChannel.ChannelType() != "session" { | |||||
| newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") | |||||
| continue | |||||
| } | |||||
| channel, requests, err := newChannel.Accept() | |||||
| if err != nil { | |||||
| panic("could not accept channel.") | |||||
| } | |||||
| // Sessions have out-of-band requests such as "shell", | |||||
| // "pty-req" and "env". Here we handle only the | |||||
| // "shell" request. | |||||
| go func(in <-chan *ssh.Request) { | |||||
| for req := range in { | |||||
| ok := false | |||||
| switch req.Type { | |||||
| case "shell": | |||||
| ok = true | |||||
| if len(req.Payload) > 0 { | |||||
| // We don't accept any | |||||
| // commands, only the | |||||
| // default shell. | |||||
| ok = false | |||||
| } | |||||
| } | |||||
| req.Reply(ok, nil) | |||||
| } | |||||
| }(requests) | |||||
| term := terminal.NewTerminal(channel, "> ") | |||||
| go func() { | |||||
| defer channel.Close() | |||||
| for { | |||||
| line, err := term.ReadLine() | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| fmt.Println(line) | |||||
| } | |||||
| }() | |||||
| } | |||||
| } | |||||
| func ExampleDial() { | |||||
| // An SSH client is represented with a ClientConn. Currently only | |||||
| // the "password" authentication method is supported. | |||||
| // | |||||
| // To authenticate with the remote server you must pass at least one | |||||
| // implementation of AuthMethod via the Auth field in ClientConfig. | |||||
| config := &ssh.ClientConfig{ | |||||
| User: "username", | |||||
| Auth: []ssh.AuthMethod{ | |||||
| ssh.Password("yourpassword"), | |||||
| }, | |||||
| } | |||||
| client, err := ssh.Dial("tcp", "yourserver.com:22", config) | |||||
| if err != nil { | |||||
| panic("Failed to dial: " + err.Error()) | |||||
| } | |||||
| // Each ClientConn can support multiple interactive sessions, | |||||
| // represented by a Session. | |||||
| session, err := client.NewSession() | |||||
| if err != nil { | |||||
| panic("Failed to create session: " + err.Error()) | |||||
| } | |||||
| defer session.Close() | |||||
| // Once a Session is created, you can execute a single command on | |||||
| // the remote side using the Run method. | |||||
| var b bytes.Buffer | |||||
| session.Stdout = &b | |||||
| if err := session.Run("/usr/bin/whoami"); err != nil { | |||||
| panic("Failed to run: " + err.Error()) | |||||
| } | |||||
| fmt.Println(b.String()) | |||||
| } | |||||
| func ExampleClient_Listen() { | |||||
| config := &ssh.ClientConfig{ | |||||
| User: "username", | |||||
| Auth: []ssh.AuthMethod{ | |||||
| ssh.Password("password"), | |||||
| }, | |||||
| } | |||||
| // Dial your ssh server. | |||||
| conn, err := ssh.Dial("tcp", "localhost:22", config) | |||||
| if err != nil { | |||||
| log.Fatalf("unable to connect: %s", err) | |||||
| } | |||||
| defer conn.Close() | |||||
| // Request the remote side to open port 8080 on all interfaces. | |||||
| l, err := conn.Listen("tcp", "0.0.0.0:8080") | |||||
| if err != nil { | |||||
| log.Fatalf("unable to register tcp forward: %v", err) | |||||
| } | |||||
| defer l.Close() | |||||
| // Serve HTTP with your SSH server acting as a reverse proxy. | |||||
| http.Serve(l, http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { | |||||
| fmt.Fprintf(resp, "Hello world!\n") | |||||
| })) | |||||
| } | |||||
| func ExampleSession_RequestPty() { | |||||
| // Create client config | |||||
| config := &ssh.ClientConfig{ | |||||
| User: "username", | |||||
| Auth: []ssh.AuthMethod{ | |||||
| ssh.Password("password"), | |||||
| }, | |||||
| } | |||||
| // Connect to ssh server | |||||
| conn, err := ssh.Dial("tcp", "localhost:22", config) | |||||
| if err != nil { | |||||
| log.Fatalf("unable to connect: %s", err) | |||||
| } | |||||
| defer conn.Close() | |||||
| // Create a session | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| log.Fatalf("unable to create session: %s", err) | |||||
| } | |||||
| defer session.Close() | |||||
| // Set up terminal modes | |||||
| modes := ssh.TerminalModes{ | |||||
| ssh.ECHO: 0, // disable echoing | |||||
| ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud | |||||
| ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud | |||||
| } | |||||
| // Request pseudo terminal | |||||
| if err := session.RequestPty("xterm", 80, 40, modes); err != nil { | |||||
| log.Fatalf("request for pseudo terminal failed: %s", err) | |||||
| } | |||||
| // Start remote shell | |||||
| if err := session.Shell(); err != nil { | |||||
| log.Fatalf("failed to start shell: %s", err) | |||||
| } | |||||
| } | |||||
| @@ -1,412 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "log" | |||||
| "net" | |||||
| "sync" | |||||
| ) | |||||
| // debugHandshake, if set, prints messages sent and received. Key | |||||
| // exchange messages are printed as if DH were used, so the debug | |||||
| // messages are wrong when using ECDH. | |||||
| const debugHandshake = false | |||||
| // keyingTransport is a packet based transport that supports key | |||||
| // changes. It need not be thread-safe. It should pass through | |||||
| // msgNewKeys in both directions. | |||||
| type keyingTransport interface { | |||||
| packetConn | |||||
| // prepareKeyChange sets up a key change. The key change for a | |||||
| // direction will be effected if a msgNewKeys message is sent | |||||
| // or received. | |||||
| prepareKeyChange(*algorithms, *kexResult) error | |||||
| // getSessionID returns the session ID. prepareKeyChange must | |||||
| // have been called once. | |||||
| getSessionID() []byte | |||||
| } | |||||
| // rekeyingTransport is the interface of handshakeTransport that we | |||||
| // (internally) expose to ClientConn and ServerConn. | |||||
| type rekeyingTransport interface { | |||||
| packetConn | |||||
| // requestKeyChange asks the remote side to change keys. All | |||||
| // writes are blocked until the key change succeeds, which is | |||||
| // signaled by reading a msgNewKeys. | |||||
| requestKeyChange() error | |||||
| // getSessionID returns the session ID. This is only valid | |||||
| // after the first key change has completed. | |||||
| getSessionID() []byte | |||||
| } | |||||
| // handshakeTransport implements rekeying on top of a keyingTransport | |||||
| // and offers a thread-safe writePacket() interface. | |||||
| type handshakeTransport struct { | |||||
| conn keyingTransport | |||||
| config *Config | |||||
| serverVersion []byte | |||||
| clientVersion []byte | |||||
| // hostKeys is non-empty if we are the server. In that case, | |||||
| // it contains all host keys that can be used to sign the | |||||
| // connection. | |||||
| hostKeys []Signer | |||||
| // hostKeyAlgorithms is non-empty if we are the client. In that case, | |||||
| // we accept these key types from the server as host key. | |||||
| hostKeyAlgorithms []string | |||||
| // On read error, incoming is closed, and readError is set. | |||||
| incoming chan []byte | |||||
| readError error | |||||
| // data for host key checking | |||||
| hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |||||
| dialAddress string | |||||
| remoteAddr net.Addr | |||||
| readSinceKex uint64 | |||||
| // Protects the writing side of the connection | |||||
| mu sync.Mutex | |||||
| cond *sync.Cond | |||||
| sentInitPacket []byte | |||||
| sentInitMsg *kexInitMsg | |||||
| writtenSinceKex uint64 | |||||
| writeError error | |||||
| } | |||||
| func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, serverVersion []byte) *handshakeTransport { | |||||
| t := &handshakeTransport{ | |||||
| conn: conn, | |||||
| serverVersion: serverVersion, | |||||
| clientVersion: clientVersion, | |||||
| incoming: make(chan []byte, 16), | |||||
| config: config, | |||||
| } | |||||
| t.cond = sync.NewCond(&t.mu) | |||||
| return t | |||||
| } | |||||
| func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { | |||||
| t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) | |||||
| t.dialAddress = dialAddr | |||||
| t.remoteAddr = addr | |||||
| t.hostKeyCallback = config.HostKeyCallback | |||||
| if config.HostKeyAlgorithms != nil { | |||||
| t.hostKeyAlgorithms = config.HostKeyAlgorithms | |||||
| } else { | |||||
| t.hostKeyAlgorithms = supportedHostKeyAlgos | |||||
| } | |||||
| go t.readLoop() | |||||
| return t | |||||
| } | |||||
| func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byte, config *ServerConfig) *handshakeTransport { | |||||
| t := newHandshakeTransport(conn, &config.Config, clientVersion, serverVersion) | |||||
| t.hostKeys = config.hostKeys | |||||
| go t.readLoop() | |||||
| return t | |||||
| } | |||||
| func (t *handshakeTransport) getSessionID() []byte { | |||||
| return t.conn.getSessionID() | |||||
| } | |||||
| func (t *handshakeTransport) id() string { | |||||
| if len(t.hostKeys) > 0 { | |||||
| return "server" | |||||
| } | |||||
| return "client" | |||||
| } | |||||
| func (t *handshakeTransport) readPacket() ([]byte, error) { | |||||
| p, ok := <-t.incoming | |||||
| if !ok { | |||||
| return nil, t.readError | |||||
| } | |||||
| return p, nil | |||||
| } | |||||
| func (t *handshakeTransport) readLoop() { | |||||
| for { | |||||
| p, err := t.readOnePacket() | |||||
| if err != nil { | |||||
| t.readError = err | |||||
| close(t.incoming) | |||||
| break | |||||
| } | |||||
| if p[0] == msgIgnore || p[0] == msgDebug { | |||||
| continue | |||||
| } | |||||
| t.incoming <- p | |||||
| } | |||||
| // If we can't read, declare the writing part dead too. | |||||
| t.mu.Lock() | |||||
| defer t.mu.Unlock() | |||||
| if t.writeError == nil { | |||||
| t.writeError = t.readError | |||||
| } | |||||
| t.cond.Broadcast() | |||||
| } | |||||
| func (t *handshakeTransport) readOnePacket() ([]byte, error) { | |||||
| if t.readSinceKex > t.config.RekeyThreshold { | |||||
| if err := t.requestKeyChange(); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| } | |||||
| p, err := t.conn.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| t.readSinceKex += uint64(len(p)) | |||||
| if debugHandshake { | |||||
| msg, err := decode(p) | |||||
| log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) | |||||
| } | |||||
| if p[0] != msgKexInit { | |||||
| return p, nil | |||||
| } | |||||
| err = t.enterKeyExchange(p) | |||||
| t.mu.Lock() | |||||
| if err != nil { | |||||
| // drop connection | |||||
| t.conn.Close() | |||||
| t.writeError = err | |||||
| } | |||||
| if debugHandshake { | |||||
| log.Printf("%s exited key exchange, err %v", t.id(), err) | |||||
| } | |||||
| // Unblock writers. | |||||
| t.sentInitMsg = nil | |||||
| t.sentInitPacket = nil | |||||
| t.cond.Broadcast() | |||||
| t.writtenSinceKex = 0 | |||||
| t.mu.Unlock() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| t.readSinceKex = 0 | |||||
| return []byte{msgNewKeys}, nil | |||||
| } | |||||
| // sendKexInit sends a key change message, and returns the message | |||||
| // that was sent. After initiating the key change, all writes will be | |||||
| // blocked until the change is done, and a failed key change will | |||||
| // close the underlying transport. This function is safe for | |||||
| // concurrent use by multiple goroutines. | |||||
| func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { | |||||
| t.mu.Lock() | |||||
| defer t.mu.Unlock() | |||||
| return t.sendKexInitLocked() | |||||
| } | |||||
| func (t *handshakeTransport) requestKeyChange() error { | |||||
| _, _, err := t.sendKexInit() | |||||
| return err | |||||
| } | |||||
| // sendKexInitLocked sends a key change message. t.mu must be locked | |||||
| // while this happens. | |||||
| func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { | |||||
| // kexInits may be sent either in response to the other side, | |||||
| // or because our side wants to initiate a key change, so we | |||||
| // may have already sent a kexInit. In that case, don't send a | |||||
| // second kexInit. | |||||
| if t.sentInitMsg != nil { | |||||
| return t.sentInitMsg, t.sentInitPacket, nil | |||||
| } | |||||
| msg := &kexInitMsg{ | |||||
| KexAlgos: t.config.KeyExchanges, | |||||
| CiphersClientServer: t.config.Ciphers, | |||||
| CiphersServerClient: t.config.Ciphers, | |||||
| MACsClientServer: t.config.MACs, | |||||
| MACsServerClient: t.config.MACs, | |||||
| CompressionClientServer: supportedCompressions, | |||||
| CompressionServerClient: supportedCompressions, | |||||
| } | |||||
| io.ReadFull(rand.Reader, msg.Cookie[:]) | |||||
| if len(t.hostKeys) > 0 { | |||||
| for _, k := range t.hostKeys { | |||||
| msg.ServerHostKeyAlgos = append( | |||||
| msg.ServerHostKeyAlgos, k.PublicKey().Type()) | |||||
| } | |||||
| } else { | |||||
| msg.ServerHostKeyAlgos = t.hostKeyAlgorithms | |||||
| } | |||||
| packet := Marshal(msg) | |||||
| // writePacket destroys the contents, so save a copy. | |||||
| packetCopy := make([]byte, len(packet)) | |||||
| copy(packetCopy, packet) | |||||
| if err := t.conn.writePacket(packetCopy); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| t.sentInitMsg = msg | |||||
| t.sentInitPacket = packet | |||||
| return msg, packet, nil | |||||
| } | |||||
| func (t *handshakeTransport) writePacket(p []byte) error { | |||||
| t.mu.Lock() | |||||
| defer t.mu.Unlock() | |||||
| if t.writtenSinceKex > t.config.RekeyThreshold { | |||||
| t.sendKexInitLocked() | |||||
| } | |||||
| for t.sentInitMsg != nil && t.writeError == nil { | |||||
| t.cond.Wait() | |||||
| } | |||||
| if t.writeError != nil { | |||||
| return t.writeError | |||||
| } | |||||
| t.writtenSinceKex += uint64(len(p)) | |||||
| switch p[0] { | |||||
| case msgKexInit: | |||||
| return errors.New("ssh: only handshakeTransport can send kexInit") | |||||
| case msgNewKeys: | |||||
| return errors.New("ssh: only handshakeTransport can send newKeys") | |||||
| default: | |||||
| return t.conn.writePacket(p) | |||||
| } | |||||
| } | |||||
| func (t *handshakeTransport) Close() error { | |||||
| return t.conn.Close() | |||||
| } | |||||
| // enterKeyExchange runs the key exchange. | |||||
| func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { | |||||
| if debugHandshake { | |||||
| log.Printf("%s entered key exchange", t.id()) | |||||
| } | |||||
| myInit, myInitPacket, err := t.sendKexInit() | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| otherInit := &kexInitMsg{} | |||||
| if err := Unmarshal(otherInitPacket, otherInit); err != nil { | |||||
| return err | |||||
| } | |||||
| magics := handshakeMagics{ | |||||
| clientVersion: t.clientVersion, | |||||
| serverVersion: t.serverVersion, | |||||
| clientKexInit: otherInitPacket, | |||||
| serverKexInit: myInitPacket, | |||||
| } | |||||
| clientInit := otherInit | |||||
| serverInit := myInit | |||||
| if len(t.hostKeys) == 0 { | |||||
| clientInit = myInit | |||||
| serverInit = otherInit | |||||
| magics.clientKexInit = myInitPacket | |||||
| magics.serverKexInit = otherInitPacket | |||||
| } | |||||
| algs, err := findAgreedAlgorithms(clientInit, serverInit) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| // We don't send FirstKexFollows, but we handle receiving it. | |||||
| if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { | |||||
| // other side sent a kex message for the wrong algorithm, | |||||
| // which we have to ignore. | |||||
| if _, err := t.conn.readPacket(); err != nil { | |||||
| return err | |||||
| } | |||||
| } | |||||
| kex, ok := kexAlgoMap[algs.kex] | |||||
| if !ok { | |||||
| return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex) | |||||
| } | |||||
| var result *kexResult | |||||
| if len(t.hostKeys) > 0 { | |||||
| result, err = t.server(kex, algs, &magics) | |||||
| } else { | |||||
| result, err = t.client(kex, algs, &magics) | |||||
| } | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| t.conn.prepareKeyChange(algs, result) | |||||
| if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { | |||||
| return err | |||||
| } | |||||
| if packet, err := t.conn.readPacket(); err != nil { | |||||
| return err | |||||
| } else if packet[0] != msgNewKeys { | |||||
| return unexpectedMessageError(msgNewKeys, packet[0]) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { | |||||
| var hostKey Signer | |||||
| for _, k := range t.hostKeys { | |||||
| if algs.hostKey == k.PublicKey().Type() { | |||||
| hostKey = k | |||||
| } | |||||
| } | |||||
| r, err := kex.Server(t.conn, t.config.Rand, magics, hostKey) | |||||
| return r, err | |||||
| } | |||||
| func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *handshakeMagics) (*kexResult, error) { | |||||
| result, err := kex.Client(t.conn, t.config.Rand, magics) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| hostKey, err := ParsePublicKey(result.HostKey) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if err := verifyHostKeySignature(hostKey, result); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if t.hostKeyCallback != nil { | |||||
| err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| } | |||||
| return result, nil | |||||
| } | |||||
| @@ -1,415 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "errors" | |||||
| "fmt" | |||||
| "net" | |||||
| "runtime" | |||||
| "strings" | |||||
| "sync" | |||||
| "testing" | |||||
| ) | |||||
| type testChecker struct { | |||||
| calls []string | |||||
| } | |||||
| func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { | |||||
| if dialAddr == "bad" { | |||||
| return fmt.Errorf("dialAddr is bad") | |||||
| } | |||||
| if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { | |||||
| return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) | |||||
| } | |||||
| t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) | |||||
| return nil | |||||
| } | |||||
| // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and | |||||
| // therefore is buffered (net.Pipe deadlocks if both sides start with | |||||
| // a write.) | |||||
| func netPipe() (net.Conn, net.Conn, error) { | |||||
| listener, err := net.Listen("tcp", "127.0.0.1:0") | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| defer listener.Close() | |||||
| c1, err := net.Dial("tcp", listener.Addr().String()) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| c2, err := listener.Accept() | |||||
| if err != nil { | |||||
| c1.Close() | |||||
| return nil, nil, err | |||||
| } | |||||
| return c1, c2, nil | |||||
| } | |||||
| func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { | |||||
| a, b, err := netPipe() | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| trC := newTransport(a, rand.Reader, true) | |||||
| trS := newTransport(b, rand.Reader, false) | |||||
| clientConf.SetDefaults() | |||||
| v := []byte("version") | |||||
| client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) | |||||
| serverConf := &ServerConfig{} | |||||
| serverConf.AddHostKey(testSigners["ecdsa"]) | |||||
| serverConf.AddHostKey(testSigners["rsa"]) | |||||
| serverConf.SetDefaults() | |||||
| server = newServerTransport(trS, v, v, serverConf) | |||||
| return client, server, nil | |||||
| } | |||||
| func TestHandshakeBasic(t *testing.T) { | |||||
| if runtime.GOOS == "plan9" { | |||||
| t.Skip("see golang.org/issue/7237") | |||||
| } | |||||
| checker := &testChecker{} | |||||
| trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") | |||||
| if err != nil { | |||||
| t.Fatalf("handshakePair: %v", err) | |||||
| } | |||||
| defer trC.Close() | |||||
| defer trS.Close() | |||||
| go func() { | |||||
| // Client writes a bunch of stuff, and does a key | |||||
| // change in the middle. This should not confuse the | |||||
| // handshake in progress | |||||
| for i := 0; i < 10; i++ { | |||||
| p := []byte{msgRequestSuccess, byte(i)} | |||||
| if err := trC.writePacket(p); err != nil { | |||||
| t.Fatalf("sendPacket: %v", err) | |||||
| } | |||||
| if i == 5 { | |||||
| // halfway through, we request a key change. | |||||
| _, _, err := trC.sendKexInit() | |||||
| if err != nil { | |||||
| t.Fatalf("sendKexInit: %v", err) | |||||
| } | |||||
| } | |||||
| } | |||||
| trC.Close() | |||||
| }() | |||||
| // Server checks that client messages come in cleanly | |||||
| i := 0 | |||||
| for { | |||||
| p, err := trS.readPacket() | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| if p[0] == msgNewKeys { | |||||
| continue | |||||
| } | |||||
| want := []byte{msgRequestSuccess, byte(i)} | |||||
| if bytes.Compare(p, want) != 0 { | |||||
| t.Errorf("message %d: got %q, want %q", i, p, want) | |||||
| } | |||||
| i++ | |||||
| } | |||||
| if i != 10 { | |||||
| t.Errorf("received %d messages, want 10.", i) | |||||
| } | |||||
| // If all went well, we registered exactly 1 key change. | |||||
| if len(checker.calls) != 1 { | |||||
| t.Fatalf("got %d host key checks, want 1", len(checker.calls)) | |||||
| } | |||||
| pub := testSigners["ecdsa"].PublicKey() | |||||
| want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) | |||||
| if want != checker.calls[0] { | |||||
| t.Errorf("got %q want %q for host key check", checker.calls[0], want) | |||||
| } | |||||
| } | |||||
| func TestHandshakeError(t *testing.T) { | |||||
| checker := &testChecker{} | |||||
| trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") | |||||
| if err != nil { | |||||
| t.Fatalf("handshakePair: %v", err) | |||||
| } | |||||
| defer trC.Close() | |||||
| defer trS.Close() | |||||
| // send a packet | |||||
| packet := []byte{msgRequestSuccess, 42} | |||||
| if err := trC.writePacket(packet); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| // Now request a key change. | |||||
| _, _, err = trC.sendKexInit() | |||||
| if err != nil { | |||||
| t.Errorf("sendKexInit: %v", err) | |||||
| } | |||||
| // the key change will fail, and afterwards we can't write. | |||||
| if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { | |||||
| t.Errorf("writePacket after botched rekey succeeded.") | |||||
| } | |||||
| readback, err := trS.readPacket() | |||||
| if err != nil { | |||||
| t.Fatalf("server closed too soon: %v", err) | |||||
| } | |||||
| if bytes.Compare(readback, packet) != 0 { | |||||
| t.Errorf("got %q want %q", readback, packet) | |||||
| } | |||||
| readback, err = trS.readPacket() | |||||
| if err == nil { | |||||
| t.Errorf("got a message %q after failed key change", readback) | |||||
| } | |||||
| } | |||||
| func TestHandshakeTwice(t *testing.T) { | |||||
| checker := &testChecker{} | |||||
| trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") | |||||
| if err != nil { | |||||
| t.Fatalf("handshakePair: %v", err) | |||||
| } | |||||
| defer trC.Close() | |||||
| defer trS.Close() | |||||
| // send a packet | |||||
| packet := make([]byte, 5) | |||||
| packet[0] = msgRequestSuccess | |||||
| if err := trC.writePacket(packet); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| // Now request a key change. | |||||
| _, _, err = trC.sendKexInit() | |||||
| if err != nil { | |||||
| t.Errorf("sendKexInit: %v", err) | |||||
| } | |||||
| // Send another packet. Use a fresh one, since writePacket destroys. | |||||
| packet = make([]byte, 5) | |||||
| packet[0] = msgRequestSuccess | |||||
| if err := trC.writePacket(packet); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| // 2nd key change. | |||||
| _, _, err = trC.sendKexInit() | |||||
| if err != nil { | |||||
| t.Errorf("sendKexInit: %v", err) | |||||
| } | |||||
| packet = make([]byte, 5) | |||||
| packet[0] = msgRequestSuccess | |||||
| if err := trC.writePacket(packet); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| packet = make([]byte, 5) | |||||
| packet[0] = msgRequestSuccess | |||||
| for i := 0; i < 5; i++ { | |||||
| msg, err := trS.readPacket() | |||||
| if err != nil { | |||||
| t.Fatalf("server closed too soon: %v", err) | |||||
| } | |||||
| if msg[0] == msgNewKeys { | |||||
| continue | |||||
| } | |||||
| if bytes.Compare(msg, packet) != 0 { | |||||
| t.Errorf("packet %d: got %q want %q", i, msg, packet) | |||||
| } | |||||
| } | |||||
| if len(checker.calls) != 2 { | |||||
| t.Errorf("got %d key changes, want 2", len(checker.calls)) | |||||
| } | |||||
| } | |||||
| func TestHandshakeAutoRekeyWrite(t *testing.T) { | |||||
| checker := &testChecker{} | |||||
| clientConf := &ClientConfig{HostKeyCallback: checker.Check} | |||||
| clientConf.RekeyThreshold = 500 | |||||
| trC, trS, err := handshakePair(clientConf, "addr") | |||||
| if err != nil { | |||||
| t.Fatalf("handshakePair: %v", err) | |||||
| } | |||||
| defer trC.Close() | |||||
| defer trS.Close() | |||||
| for i := 0; i < 5; i++ { | |||||
| packet := make([]byte, 251) | |||||
| packet[0] = msgRequestSuccess | |||||
| if err := trC.writePacket(packet); err != nil { | |||||
| t.Errorf("writePacket: %v", err) | |||||
| } | |||||
| } | |||||
| j := 0 | |||||
| for ; j < 5; j++ { | |||||
| _, err := trS.readPacket() | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| } | |||||
| if j != 5 { | |||||
| t.Errorf("got %d, want 5 messages", j) | |||||
| } | |||||
| if len(checker.calls) != 2 { | |||||
| t.Errorf("got %d key changes, wanted 2", len(checker.calls)) | |||||
| } | |||||
| } | |||||
| type syncChecker struct { | |||||
| called chan int | |||||
| } | |||||
| func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { | |||||
| t.called <- 1 | |||||
| return nil | |||||
| } | |||||
| func TestHandshakeAutoRekeyRead(t *testing.T) { | |||||
| sync := &syncChecker{make(chan int, 2)} | |||||
| clientConf := &ClientConfig{ | |||||
| HostKeyCallback: sync.Check, | |||||
| } | |||||
| clientConf.RekeyThreshold = 500 | |||||
| trC, trS, err := handshakePair(clientConf, "addr") | |||||
| if err != nil { | |||||
| t.Fatalf("handshakePair: %v", err) | |||||
| } | |||||
| defer trC.Close() | |||||
| defer trS.Close() | |||||
| packet := make([]byte, 501) | |||||
| packet[0] = msgRequestSuccess | |||||
| if err := trS.writePacket(packet); err != nil { | |||||
| t.Fatalf("writePacket: %v", err) | |||||
| } | |||||
| // While we read out the packet, a key change will be | |||||
| // initiated. | |||||
| if _, err := trC.readPacket(); err != nil { | |||||
| t.Fatalf("readPacket(client): %v", err) | |||||
| } | |||||
| <-sync.called | |||||
| } | |||||
| // errorKeyingTransport generates errors after a given number of | |||||
| // read/write operations. | |||||
| type errorKeyingTransport struct { | |||||
| packetConn | |||||
| readLeft, writeLeft int | |||||
| } | |||||
| func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { | |||||
| return nil | |||||
| } | |||||
| func (n *errorKeyingTransport) getSessionID() []byte { | |||||
| return nil | |||||
| } | |||||
| func (n *errorKeyingTransport) writePacket(packet []byte) error { | |||||
| if n.writeLeft == 0 { | |||||
| n.Close() | |||||
| return errors.New("barf") | |||||
| } | |||||
| n.writeLeft-- | |||||
| return n.packetConn.writePacket(packet) | |||||
| } | |||||
| func (n *errorKeyingTransport) readPacket() ([]byte, error) { | |||||
| if n.readLeft == 0 { | |||||
| n.Close() | |||||
| return nil, errors.New("barf") | |||||
| } | |||||
| n.readLeft-- | |||||
| return n.packetConn.readPacket() | |||||
| } | |||||
| func TestHandshakeErrorHandlingRead(t *testing.T) { | |||||
| for i := 0; i < 20; i++ { | |||||
| testHandshakeErrorHandlingN(t, i, -1) | |||||
| } | |||||
| } | |||||
| func TestHandshakeErrorHandlingWrite(t *testing.T) { | |||||
| for i := 0; i < 20; i++ { | |||||
| testHandshakeErrorHandlingN(t, -1, i) | |||||
| } | |||||
| } | |||||
| // testHandshakeErrorHandlingN runs handshakes, injecting errors. If | |||||
| // handshakeTransport deadlocks, the go runtime will detect it and | |||||
| // panic. | |||||
| func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { | |||||
| msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) | |||||
| a, b := memPipe() | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| key := testSigners["ecdsa"] | |||||
| serverConf := Config{RekeyThreshold: minRekeyThreshold} | |||||
| serverConf.SetDefaults() | |||||
| serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) | |||||
| serverConn.hostKeys = []Signer{key} | |||||
| go serverConn.readLoop() | |||||
| clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} | |||||
| clientConf.SetDefaults() | |||||
| clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) | |||||
| clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} | |||||
| go clientConn.readLoop() | |||||
| var wg sync.WaitGroup | |||||
| wg.Add(4) | |||||
| for _, hs := range []packetConn{serverConn, clientConn} { | |||||
| go func(c packetConn) { | |||||
| for { | |||||
| err := c.writePacket(msg) | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| } | |||||
| wg.Done() | |||||
| }(hs) | |||||
| go func(c packetConn) { | |||||
| for { | |||||
| _, err := c.readPacket() | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| } | |||||
| wg.Done() | |||||
| }(hs) | |||||
| } | |||||
| wg.Wait() | |||||
| } | |||||
| @@ -1,526 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "crypto" | |||||
| "crypto/ecdsa" | |||||
| "crypto/elliptic" | |||||
| "crypto/subtle" | |||||
| "crypto/rand" | |||||
| "errors" | |||||
| "io" | |||||
| "math/big" | |||||
| "golang.org/x/crypto/curve25519" | |||||
| ) | |||||
| const ( | |||||
| kexAlgoDH1SHA1 = "diffie-hellman-group1-sha1" | |||||
| kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1" | |||||
| kexAlgoECDH256 = "ecdh-sha2-nistp256" | |||||
| kexAlgoECDH384 = "ecdh-sha2-nistp384" | |||||
| kexAlgoECDH521 = "ecdh-sha2-nistp521" | |||||
| kexAlgoCurve25519SHA256 = "curve25519-sha256@libssh.org" | |||||
| ) | |||||
| // kexResult captures the outcome of a key exchange. | |||||
| type kexResult struct { | |||||
| // Session hash. See also RFC 4253, section 8. | |||||
| H []byte | |||||
| // Shared secret. See also RFC 4253, section 8. | |||||
| K []byte | |||||
| // Host key as hashed into H. | |||||
| HostKey []byte | |||||
| // Signature of H. | |||||
| Signature []byte | |||||
| // A cryptographic hash function that matches the security | |||||
| // level of the key exchange algorithm. It is used for | |||||
| // calculating H, and for deriving keys from H and K. | |||||
| Hash crypto.Hash | |||||
| // The session ID, which is the first H computed. This is used | |||||
| // to signal data inside transport. | |||||
| SessionID []byte | |||||
| } | |||||
| // handshakeMagics contains data that is always included in the | |||||
| // session hash. | |||||
| type handshakeMagics struct { | |||||
| clientVersion, serverVersion []byte | |||||
| clientKexInit, serverKexInit []byte | |||||
| } | |||||
| func (m *handshakeMagics) write(w io.Writer) { | |||||
| writeString(w, m.clientVersion) | |||||
| writeString(w, m.serverVersion) | |||||
| writeString(w, m.clientKexInit) | |||||
| writeString(w, m.serverKexInit) | |||||
| } | |||||
| // kexAlgorithm abstracts different key exchange algorithms. | |||||
| type kexAlgorithm interface { | |||||
| // Server runs server-side key agreement, signing the result | |||||
| // with a hostkey. | |||||
| Server(p packetConn, rand io.Reader, magics *handshakeMagics, s Signer) (*kexResult, error) | |||||
| // Client runs the client-side key agreement. Caller is | |||||
| // responsible for verifying the host key signature. | |||||
| Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) | |||||
| } | |||||
| // dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement. | |||||
| type dhGroup struct { | |||||
| g, p *big.Int | |||||
| } | |||||
| func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) { | |||||
| if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 { | |||||
| return nil, errors.New("ssh: DH parameter out of bounds") | |||||
| } | |||||
| return new(big.Int).Exp(theirPublic, myPrivate, group.p), nil | |||||
| } | |||||
| func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handshakeMagics) (*kexResult, error) { | |||||
| hashFunc := crypto.SHA1 | |||||
| x, err := rand.Int(randSource, group.p) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| X := new(big.Int).Exp(group.g, x, group.p) | |||||
| kexDHInit := kexDHInitMsg{ | |||||
| X: X, | |||||
| } | |||||
| if err := c.writePacket(Marshal(&kexDHInit)); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var kexDHReply kexDHReplyMsg | |||||
| if err = Unmarshal(packet, &kexDHReply); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| kInt, err := group.diffieHellman(kexDHReply.Y, x) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| h := hashFunc.New() | |||||
| magics.write(h) | |||||
| writeString(h, kexDHReply.HostKey) | |||||
| writeInt(h, X) | |||||
| writeInt(h, kexDHReply.Y) | |||||
| K := make([]byte, intLength(kInt)) | |||||
| marshalInt(K, kInt) | |||||
| h.Write(K) | |||||
| return &kexResult{ | |||||
| H: h.Sum(nil), | |||||
| K: K, | |||||
| HostKey: kexDHReply.HostKey, | |||||
| Signature: kexDHReply.Signature, | |||||
| Hash: crypto.SHA1, | |||||
| }, nil | |||||
| } | |||||
| func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { | |||||
| hashFunc := crypto.SHA1 | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| var kexDHInit kexDHInitMsg | |||||
| if err = Unmarshal(packet, &kexDHInit); err != nil { | |||||
| return | |||||
| } | |||||
| y, err := rand.Int(randSource, group.p) | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| Y := new(big.Int).Exp(group.g, y, group.p) | |||||
| kInt, err := group.diffieHellman(kexDHInit.X, y) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| hostKeyBytes := priv.PublicKey().Marshal() | |||||
| h := hashFunc.New() | |||||
| magics.write(h) | |||||
| writeString(h, hostKeyBytes) | |||||
| writeInt(h, kexDHInit.X) | |||||
| writeInt(h, Y) | |||||
| K := make([]byte, intLength(kInt)) | |||||
| marshalInt(K, kInt) | |||||
| h.Write(K) | |||||
| H := h.Sum(nil) | |||||
| // H is already a hash, but the hostkey signing will apply its | |||||
| // own key-specific hash algorithm. | |||||
| sig, err := signAndMarshal(priv, randSource, H) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| kexDHReply := kexDHReplyMsg{ | |||||
| HostKey: hostKeyBytes, | |||||
| Y: Y, | |||||
| Signature: sig, | |||||
| } | |||||
| packet = Marshal(&kexDHReply) | |||||
| err = c.writePacket(packet) | |||||
| return &kexResult{ | |||||
| H: H, | |||||
| K: K, | |||||
| HostKey: hostKeyBytes, | |||||
| Signature: sig, | |||||
| Hash: crypto.SHA1, | |||||
| }, nil | |||||
| } | |||||
| // ecdh performs Elliptic Curve Diffie-Hellman key exchange as | |||||
| // described in RFC 5656, section 4. | |||||
| type ecdh struct { | |||||
| curve elliptic.Curve | |||||
| } | |||||
| func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { | |||||
| ephKey, err := ecdsa.GenerateKey(kex.curve, rand) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| kexInit := kexECDHInitMsg{ | |||||
| ClientPubKey: elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y), | |||||
| } | |||||
| serialized := Marshal(&kexInit) | |||||
| if err := c.writePacket(serialized); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var reply kexECDHReplyMsg | |||||
| if err = Unmarshal(packet, &reply); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| x, y, err := unmarshalECKey(kex.curve, reply.EphemeralPubKey) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| // generate shared secret | |||||
| secret, _ := kex.curve.ScalarMult(x, y, ephKey.D.Bytes()) | |||||
| h := ecHash(kex.curve).New() | |||||
| magics.write(h) | |||||
| writeString(h, reply.HostKey) | |||||
| writeString(h, kexInit.ClientPubKey) | |||||
| writeString(h, reply.EphemeralPubKey) | |||||
| K := make([]byte, intLength(secret)) | |||||
| marshalInt(K, secret) | |||||
| h.Write(K) | |||||
| return &kexResult{ | |||||
| H: h.Sum(nil), | |||||
| K: K, | |||||
| HostKey: reply.HostKey, | |||||
| Signature: reply.Signature, | |||||
| Hash: ecHash(kex.curve), | |||||
| }, nil | |||||
| } | |||||
| // unmarshalECKey parses and checks an EC key. | |||||
| func unmarshalECKey(curve elliptic.Curve, pubkey []byte) (x, y *big.Int, err error) { | |||||
| x, y = elliptic.Unmarshal(curve, pubkey) | |||||
| if x == nil { | |||||
| return nil, nil, errors.New("ssh: elliptic.Unmarshal failure") | |||||
| } | |||||
| if !validateECPublicKey(curve, x, y) { | |||||
| return nil, nil, errors.New("ssh: public key not on curve") | |||||
| } | |||||
| return x, y, nil | |||||
| } | |||||
| // validateECPublicKey checks that the point is a valid public key for | |||||
| // the given curve. See [SEC1], 3.2.2 | |||||
| func validateECPublicKey(curve elliptic.Curve, x, y *big.Int) bool { | |||||
| if x.Sign() == 0 && y.Sign() == 0 { | |||||
| return false | |||||
| } | |||||
| if x.Cmp(curve.Params().P) >= 0 { | |||||
| return false | |||||
| } | |||||
| if y.Cmp(curve.Params().P) >= 0 { | |||||
| return false | |||||
| } | |||||
| if !curve.IsOnCurve(x, y) { | |||||
| return false | |||||
| } | |||||
| // We don't check if N * PubKey == 0, since | |||||
| // | |||||
| // - the NIST curves have cofactor = 1, so this is implicit. | |||||
| // (We don't foresee an implementation that supports non NIST | |||||
| // curves) | |||||
| // | |||||
| // - for ephemeral keys, we don't need to worry about small | |||||
| // subgroup attacks. | |||||
| return true | |||||
| } | |||||
| func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var kexECDHInit kexECDHInitMsg | |||||
| if err = Unmarshal(packet, &kexECDHInit); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| clientX, clientY, err := unmarshalECKey(kex.curve, kexECDHInit.ClientPubKey) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| // We could cache this key across multiple users/multiple | |||||
| // connection attempts, but the benefit is small. OpenSSH | |||||
| // generates a new key for each incoming connection. | |||||
| ephKey, err := ecdsa.GenerateKey(kex.curve, rand) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| hostKeyBytes := priv.PublicKey().Marshal() | |||||
| serializedEphKey := elliptic.Marshal(kex.curve, ephKey.PublicKey.X, ephKey.PublicKey.Y) | |||||
| // generate shared secret | |||||
| secret, _ := kex.curve.ScalarMult(clientX, clientY, ephKey.D.Bytes()) | |||||
| h := ecHash(kex.curve).New() | |||||
| magics.write(h) | |||||
| writeString(h, hostKeyBytes) | |||||
| writeString(h, kexECDHInit.ClientPubKey) | |||||
| writeString(h, serializedEphKey) | |||||
| K := make([]byte, intLength(secret)) | |||||
| marshalInt(K, secret) | |||||
| h.Write(K) | |||||
| H := h.Sum(nil) | |||||
| // H is already a hash, but the hostkey signing will apply its | |||||
| // own key-specific hash algorithm. | |||||
| sig, err := signAndMarshal(priv, rand, H) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| reply := kexECDHReplyMsg{ | |||||
| EphemeralPubKey: serializedEphKey, | |||||
| HostKey: hostKeyBytes, | |||||
| Signature: sig, | |||||
| } | |||||
| serialized := Marshal(&reply) | |||||
| if err := c.writePacket(serialized); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &kexResult{ | |||||
| H: H, | |||||
| K: K, | |||||
| HostKey: reply.HostKey, | |||||
| Signature: sig, | |||||
| Hash: ecHash(kex.curve), | |||||
| }, nil | |||||
| } | |||||
| var kexAlgoMap = map[string]kexAlgorithm{} | |||||
| func init() { | |||||
| // This is the group called diffie-hellman-group1-sha1 in RFC | |||||
| // 4253 and Oakley Group 2 in RFC 2409. | |||||
| p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16) | |||||
| kexAlgoMap[kexAlgoDH1SHA1] = &dhGroup{ | |||||
| g: new(big.Int).SetInt64(2), | |||||
| p: p, | |||||
| } | |||||
| // This is the group called diffie-hellman-group14-sha1 in RFC | |||||
| // 4253 and Oakley Group 14 in RFC 3526. | |||||
| p, _ = new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16) | |||||
| kexAlgoMap[kexAlgoDH14SHA1] = &dhGroup{ | |||||
| g: new(big.Int).SetInt64(2), | |||||
| p: p, | |||||
| } | |||||
| kexAlgoMap[kexAlgoECDH521] = &ecdh{elliptic.P521()} | |||||
| kexAlgoMap[kexAlgoECDH384] = &ecdh{elliptic.P384()} | |||||
| kexAlgoMap[kexAlgoECDH256] = &ecdh{elliptic.P256()} | |||||
| kexAlgoMap[kexAlgoCurve25519SHA256] = &curve25519sha256{} | |||||
| } | |||||
| // curve25519sha256 implements the curve25519-sha256@libssh.org key | |||||
| // agreement protocol, as described in | |||||
| // https://git.libssh.org/projects/libssh.git/tree/doc/curve25519-sha256@libssh.org.txt | |||||
| type curve25519sha256 struct{} | |||||
| type curve25519KeyPair struct { | |||||
| priv [32]byte | |||||
| pub [32]byte | |||||
| } | |||||
| func (kp *curve25519KeyPair) generate(rand io.Reader) error { | |||||
| if _, err := io.ReadFull(rand, kp.priv[:]); err != nil { | |||||
| return err | |||||
| } | |||||
| curve25519.ScalarBaseMult(&kp.pub, &kp.priv) | |||||
| return nil | |||||
| } | |||||
| // curve25519Zeros is just an array of 32 zero bytes so that we have something | |||||
| // convenient to compare against in order to reject curve25519 points with the | |||||
| // wrong order. | |||||
| var curve25519Zeros [32]byte | |||||
| func (kex *curve25519sha256) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error) { | |||||
| var kp curve25519KeyPair | |||||
| if err := kp.generate(rand); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if err := c.writePacket(Marshal(&kexECDHInitMsg{kp.pub[:]})); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var reply kexECDHReplyMsg | |||||
| if err = Unmarshal(packet, &reply); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if len(reply.EphemeralPubKey) != 32 { | |||||
| return nil, errors.New("ssh: peer's curve25519 public value has wrong length") | |||||
| } | |||||
| var servPub, secret [32]byte | |||||
| copy(servPub[:], reply.EphemeralPubKey) | |||||
| curve25519.ScalarMult(&secret, &kp.priv, &servPub) | |||||
| if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { | |||||
| return nil, errors.New("ssh: peer's curve25519 public value has wrong order") | |||||
| } | |||||
| h := crypto.SHA256.New() | |||||
| magics.write(h) | |||||
| writeString(h, reply.HostKey) | |||||
| writeString(h, kp.pub[:]) | |||||
| writeString(h, reply.EphemeralPubKey) | |||||
| kInt := new(big.Int).SetBytes(secret[:]) | |||||
| K := make([]byte, intLength(kInt)) | |||||
| marshalInt(K, kInt) | |||||
| h.Write(K) | |||||
| return &kexResult{ | |||||
| H: h.Sum(nil), | |||||
| K: K, | |||||
| HostKey: reply.HostKey, | |||||
| Signature: reply.Signature, | |||||
| Hash: crypto.SHA256, | |||||
| }, nil | |||||
| } | |||||
| func (kex *curve25519sha256) Server(c packetConn, rand io.Reader, magics *handshakeMagics, priv Signer) (result *kexResult, err error) { | |||||
| packet, err := c.readPacket() | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| var kexInit kexECDHInitMsg | |||||
| if err = Unmarshal(packet, &kexInit); err != nil { | |||||
| return | |||||
| } | |||||
| if len(kexInit.ClientPubKey) != 32 { | |||||
| return nil, errors.New("ssh: peer's curve25519 public value has wrong length") | |||||
| } | |||||
| var kp curve25519KeyPair | |||||
| if err := kp.generate(rand); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var clientPub, secret [32]byte | |||||
| copy(clientPub[:], kexInit.ClientPubKey) | |||||
| curve25519.ScalarMult(&secret, &kp.priv, &clientPub) | |||||
| if subtle.ConstantTimeCompare(secret[:], curve25519Zeros[:]) == 1 { | |||||
| return nil, errors.New("ssh: peer's curve25519 public value has wrong order") | |||||
| } | |||||
| hostKeyBytes := priv.PublicKey().Marshal() | |||||
| h := crypto.SHA256.New() | |||||
| magics.write(h) | |||||
| writeString(h, hostKeyBytes) | |||||
| writeString(h, kexInit.ClientPubKey) | |||||
| writeString(h, kp.pub[:]) | |||||
| kInt := new(big.Int).SetBytes(secret[:]) | |||||
| K := make([]byte, intLength(kInt)) | |||||
| marshalInt(K, kInt) | |||||
| h.Write(K) | |||||
| H := h.Sum(nil) | |||||
| sig, err := signAndMarshal(priv, rand, H) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| reply := kexECDHReplyMsg{ | |||||
| EphemeralPubKey: kp.pub[:], | |||||
| HostKey: hostKeyBytes, | |||||
| Signature: sig, | |||||
| } | |||||
| if err := c.writePacket(Marshal(&reply)); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &kexResult{ | |||||
| H: H, | |||||
| K: K, | |||||
| HostKey: hostKeyBytes, | |||||
| Signature: sig, | |||||
| Hash: crypto.SHA256, | |||||
| }, nil | |||||
| } | |||||
| @@ -1,50 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| // Key exchange tests. | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "reflect" | |||||
| "testing" | |||||
| ) | |||||
| func TestKexes(t *testing.T) { | |||||
| type kexResultErr struct { | |||||
| result *kexResult | |||||
| err error | |||||
| } | |||||
| for name, kex := range kexAlgoMap { | |||||
| a, b := memPipe() | |||||
| s := make(chan kexResultErr, 1) | |||||
| c := make(chan kexResultErr, 1) | |||||
| var magics handshakeMagics | |||||
| go func() { | |||||
| r, e := kex.Client(a, rand.Reader, &magics) | |||||
| a.Close() | |||||
| c <- kexResultErr{r, e} | |||||
| }() | |||||
| go func() { | |||||
| r, e := kex.Server(b, rand.Reader, &magics, testSigners["ecdsa"]) | |||||
| b.Close() | |||||
| s <- kexResultErr{r, e} | |||||
| }() | |||||
| clientRes := <-c | |||||
| serverRes := <-s | |||||
| if clientRes.err != nil { | |||||
| t.Errorf("client: %v", clientRes.err) | |||||
| } | |||||
| if serverRes.err != nil { | |||||
| t.Errorf("server: %v", serverRes.err) | |||||
| } | |||||
| if !reflect.DeepEqual(clientRes.result, serverRes.result) { | |||||
| t.Errorf("kex %q: mismatch %#v, %#v", name, clientRes.result, serverRes.result) | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,628 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto" | |||||
| "crypto/dsa" | |||||
| "crypto/ecdsa" | |||||
| "crypto/elliptic" | |||||
| "crypto/rsa" | |||||
| "crypto/x509" | |||||
| "encoding/asn1" | |||||
| "encoding/base64" | |||||
| "encoding/pem" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "math/big" | |||||
| ) | |||||
| // These constants represent the algorithm names for key types supported by this | |||||
| // package. | |||||
| const ( | |||||
| KeyAlgoRSA = "ssh-rsa" | |||||
| KeyAlgoDSA = "ssh-dss" | |||||
| KeyAlgoECDSA256 = "ecdsa-sha2-nistp256" | |||||
| KeyAlgoECDSA384 = "ecdsa-sha2-nistp384" | |||||
| KeyAlgoECDSA521 = "ecdsa-sha2-nistp521" | |||||
| ) | |||||
| // parsePubKey parses a public key of the given algorithm. | |||||
| // Use ParsePublicKey for keys with prepended algorithm. | |||||
| func parsePubKey(in []byte, algo string) (pubKey PublicKey, rest []byte, err error) { | |||||
| switch algo { | |||||
| case KeyAlgoRSA: | |||||
| return parseRSA(in) | |||||
| case KeyAlgoDSA: | |||||
| return parseDSA(in) | |||||
| case KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521: | |||||
| return parseECDSA(in) | |||||
| case CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: | |||||
| cert, err := parseCert(in, certToPrivAlgo(algo)) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| return cert, nil, nil | |||||
| } | |||||
| return nil, nil, fmt.Errorf("ssh: unknown key algorithm: %v", err) | |||||
| } | |||||
| // parseAuthorizedKey parses a public key in OpenSSH authorized_keys format | |||||
| // (see sshd(8) manual page) once the options and key type fields have been | |||||
| // removed. | |||||
| func parseAuthorizedKey(in []byte) (out PublicKey, comment string, err error) { | |||||
| in = bytes.TrimSpace(in) | |||||
| i := bytes.IndexAny(in, " \t") | |||||
| if i == -1 { | |||||
| i = len(in) | |||||
| } | |||||
| base64Key := in[:i] | |||||
| key := make([]byte, base64.StdEncoding.DecodedLen(len(base64Key))) | |||||
| n, err := base64.StdEncoding.Decode(key, base64Key) | |||||
| if err != nil { | |||||
| return nil, "", err | |||||
| } | |||||
| key = key[:n] | |||||
| out, err = ParsePublicKey(key) | |||||
| if err != nil { | |||||
| return nil, "", err | |||||
| } | |||||
| comment = string(bytes.TrimSpace(in[i:])) | |||||
| return out, comment, nil | |||||
| } | |||||
| // ParseAuthorizedKeys parses a public key from an authorized_keys | |||||
| // file used in OpenSSH according to the sshd(8) manual page. | |||||
| func ParseAuthorizedKey(in []byte) (out PublicKey, comment string, options []string, rest []byte, err error) { | |||||
| for len(in) > 0 { | |||||
| end := bytes.IndexByte(in, '\n') | |||||
| if end != -1 { | |||||
| rest = in[end+1:] | |||||
| in = in[:end] | |||||
| } else { | |||||
| rest = nil | |||||
| } | |||||
| end = bytes.IndexByte(in, '\r') | |||||
| if end != -1 { | |||||
| in = in[:end] | |||||
| } | |||||
| in = bytes.TrimSpace(in) | |||||
| if len(in) == 0 || in[0] == '#' { | |||||
| in = rest | |||||
| continue | |||||
| } | |||||
| i := bytes.IndexAny(in, " \t") | |||||
| if i == -1 { | |||||
| in = rest | |||||
| continue | |||||
| } | |||||
| if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { | |||||
| return out, comment, options, rest, nil | |||||
| } | |||||
| // No key type recognised. Maybe there's an options field at | |||||
| // the beginning. | |||||
| var b byte | |||||
| inQuote := false | |||||
| var candidateOptions []string | |||||
| optionStart := 0 | |||||
| for i, b = range in { | |||||
| isEnd := !inQuote && (b == ' ' || b == '\t') | |||||
| if (b == ',' && !inQuote) || isEnd { | |||||
| if i-optionStart > 0 { | |||||
| candidateOptions = append(candidateOptions, string(in[optionStart:i])) | |||||
| } | |||||
| optionStart = i + 1 | |||||
| } | |||||
| if isEnd { | |||||
| break | |||||
| } | |||||
| if b == '"' && (i == 0 || (i > 0 && in[i-1] != '\\')) { | |||||
| inQuote = !inQuote | |||||
| } | |||||
| } | |||||
| for i < len(in) && (in[i] == ' ' || in[i] == '\t') { | |||||
| i++ | |||||
| } | |||||
| if i == len(in) { | |||||
| // Invalid line: unmatched quote | |||||
| in = rest | |||||
| continue | |||||
| } | |||||
| in = in[i:] | |||||
| i = bytes.IndexAny(in, " \t") | |||||
| if i == -1 { | |||||
| in = rest | |||||
| continue | |||||
| } | |||||
| if out, comment, err = parseAuthorizedKey(in[i:]); err == nil { | |||||
| options = candidateOptions | |||||
| return out, comment, options, rest, nil | |||||
| } | |||||
| in = rest | |||||
| continue | |||||
| } | |||||
| return nil, "", nil, nil, errors.New("ssh: no key found") | |||||
| } | |||||
| // ParsePublicKey parses an SSH public key formatted for use in | |||||
| // the SSH wire protocol according to RFC 4253, section 6.6. | |||||
| func ParsePublicKey(in []byte) (out PublicKey, err error) { | |||||
| algo, in, ok := parseString(in) | |||||
| if !ok { | |||||
| return nil, errShortRead | |||||
| } | |||||
| var rest []byte | |||||
| out, rest, err = parsePubKey(in, string(algo)) | |||||
| if len(rest) > 0 { | |||||
| return nil, errors.New("ssh: trailing junk in public key") | |||||
| } | |||||
| return out, err | |||||
| } | |||||
| // MarshalAuthorizedKey serializes key for inclusion in an OpenSSH | |||||
| // authorized_keys file. The return value ends with newline. | |||||
| func MarshalAuthorizedKey(key PublicKey) []byte { | |||||
| b := &bytes.Buffer{} | |||||
| b.WriteString(key.Type()) | |||||
| b.WriteByte(' ') | |||||
| e := base64.NewEncoder(base64.StdEncoding, b) | |||||
| e.Write(key.Marshal()) | |||||
| e.Close() | |||||
| b.WriteByte('\n') | |||||
| return b.Bytes() | |||||
| } | |||||
| // PublicKey is an abstraction of different types of public keys. | |||||
| type PublicKey interface { | |||||
| // Type returns the key's type, e.g. "ssh-rsa". | |||||
| Type() string | |||||
| // Marshal returns the serialized key data in SSH wire format, | |||||
| // with the name prefix. | |||||
| Marshal() []byte | |||||
| // Verify that sig is a signature on the given data using this | |||||
| // key. This function will hash the data appropriately first. | |||||
| Verify(data []byte, sig *Signature) error | |||||
| } | |||||
| // A Signer can create signatures that verify against a public key. | |||||
| type Signer interface { | |||||
| // PublicKey returns an associated PublicKey instance. | |||||
| PublicKey() PublicKey | |||||
| // Sign returns raw signature for the given data. This method | |||||
| // will apply the hash specified for the keytype to the data. | |||||
| Sign(rand io.Reader, data []byte) (*Signature, error) | |||||
| } | |||||
| type rsaPublicKey rsa.PublicKey | |||||
| func (r *rsaPublicKey) Type() string { | |||||
| return "ssh-rsa" | |||||
| } | |||||
| // parseRSA parses an RSA key according to RFC 4253, section 6.6. | |||||
| func parseRSA(in []byte) (out PublicKey, rest []byte, err error) { | |||||
| var w struct { | |||||
| E *big.Int | |||||
| N *big.Int | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| if err := Unmarshal(in, &w); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| if w.E.BitLen() > 24 { | |||||
| return nil, nil, errors.New("ssh: exponent too large") | |||||
| } | |||||
| e := w.E.Int64() | |||||
| if e < 3 || e&1 == 0 { | |||||
| return nil, nil, errors.New("ssh: incorrect exponent") | |||||
| } | |||||
| var key rsa.PublicKey | |||||
| key.E = int(e) | |||||
| key.N = w.N | |||||
| return (*rsaPublicKey)(&key), w.Rest, nil | |||||
| } | |||||
| func (r *rsaPublicKey) Marshal() []byte { | |||||
| e := new(big.Int).SetInt64(int64(r.E)) | |||||
| wirekey := struct { | |||||
| Name string | |||||
| E *big.Int | |||||
| N *big.Int | |||||
| }{ | |||||
| KeyAlgoRSA, | |||||
| e, | |||||
| r.N, | |||||
| } | |||||
| return Marshal(&wirekey) | |||||
| } | |||||
| func (r *rsaPublicKey) Verify(data []byte, sig *Signature) error { | |||||
| if sig.Format != r.Type() { | |||||
| return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, r.Type()) | |||||
| } | |||||
| h := crypto.SHA1.New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| return rsa.VerifyPKCS1v15((*rsa.PublicKey)(r), crypto.SHA1, digest, sig.Blob) | |||||
| } | |||||
| type rsaPrivateKey struct { | |||||
| *rsa.PrivateKey | |||||
| } | |||||
| func (r *rsaPrivateKey) PublicKey() PublicKey { | |||||
| return (*rsaPublicKey)(&r.PrivateKey.PublicKey) | |||||
| } | |||||
| func (r *rsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { | |||||
| h := crypto.SHA1.New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| blob, err := rsa.SignPKCS1v15(rand, r.PrivateKey, crypto.SHA1, digest) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &Signature{ | |||||
| Format: r.PublicKey().Type(), | |||||
| Blob: blob, | |||||
| }, nil | |||||
| } | |||||
| type dsaPublicKey dsa.PublicKey | |||||
| func (r *dsaPublicKey) Type() string { | |||||
| return "ssh-dss" | |||||
| } | |||||
| // parseDSA parses an DSA key according to RFC 4253, section 6.6. | |||||
| func parseDSA(in []byte) (out PublicKey, rest []byte, err error) { | |||||
| var w struct { | |||||
| P, Q, G, Y *big.Int | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| if err := Unmarshal(in, &w); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| key := &dsaPublicKey{ | |||||
| Parameters: dsa.Parameters{ | |||||
| P: w.P, | |||||
| Q: w.Q, | |||||
| G: w.G, | |||||
| }, | |||||
| Y: w.Y, | |||||
| } | |||||
| return key, w.Rest, nil | |||||
| } | |||||
| func (k *dsaPublicKey) Marshal() []byte { | |||||
| w := struct { | |||||
| Name string | |||||
| P, Q, G, Y *big.Int | |||||
| }{ | |||||
| k.Type(), | |||||
| k.P, | |||||
| k.Q, | |||||
| k.G, | |||||
| k.Y, | |||||
| } | |||||
| return Marshal(&w) | |||||
| } | |||||
| func (k *dsaPublicKey) Verify(data []byte, sig *Signature) error { | |||||
| if sig.Format != k.Type() { | |||||
| return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, k.Type()) | |||||
| } | |||||
| h := crypto.SHA1.New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| // Per RFC 4253, section 6.6, | |||||
| // The value for 'dss_signature_blob' is encoded as a string containing | |||||
| // r, followed by s (which are 160-bit integers, without lengths or | |||||
| // padding, unsigned, and in network byte order). | |||||
| // For DSS purposes, sig.Blob should be exactly 40 bytes in length. | |||||
| if len(sig.Blob) != 40 { | |||||
| return errors.New("ssh: DSA signature parse error") | |||||
| } | |||||
| r := new(big.Int).SetBytes(sig.Blob[:20]) | |||||
| s := new(big.Int).SetBytes(sig.Blob[20:]) | |||||
| if dsa.Verify((*dsa.PublicKey)(k), digest, r, s) { | |||||
| return nil | |||||
| } | |||||
| return errors.New("ssh: signature did not verify") | |||||
| } | |||||
| type dsaPrivateKey struct { | |||||
| *dsa.PrivateKey | |||||
| } | |||||
| func (k *dsaPrivateKey) PublicKey() PublicKey { | |||||
| return (*dsaPublicKey)(&k.PrivateKey.PublicKey) | |||||
| } | |||||
| func (k *dsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { | |||||
| h := crypto.SHA1.New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| r, s, err := dsa.Sign(rand, k.PrivateKey, digest) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| sig := make([]byte, 40) | |||||
| rb := r.Bytes() | |||||
| sb := s.Bytes() | |||||
| copy(sig[20-len(rb):20], rb) | |||||
| copy(sig[40-len(sb):], sb) | |||||
| return &Signature{ | |||||
| Format: k.PublicKey().Type(), | |||||
| Blob: sig, | |||||
| }, nil | |||||
| } | |||||
| type ecdsaPublicKey ecdsa.PublicKey | |||||
| func (key *ecdsaPublicKey) Type() string { | |||||
| return "ecdsa-sha2-" + key.nistID() | |||||
| } | |||||
| func (key *ecdsaPublicKey) nistID() string { | |||||
| switch key.Params().BitSize { | |||||
| case 256: | |||||
| return "nistp256" | |||||
| case 384: | |||||
| return "nistp384" | |||||
| case 521: | |||||
| return "nistp521" | |||||
| } | |||||
| panic("ssh: unsupported ecdsa key size") | |||||
| } | |||||
| func supportedEllipticCurve(curve elliptic.Curve) bool { | |||||
| return curve == elliptic.P256() || curve == elliptic.P384() || curve == elliptic.P521() | |||||
| } | |||||
| // ecHash returns the hash to match the given elliptic curve, see RFC | |||||
| // 5656, section 6.2.1 | |||||
| func ecHash(curve elliptic.Curve) crypto.Hash { | |||||
| bitSize := curve.Params().BitSize | |||||
| switch { | |||||
| case bitSize <= 256: | |||||
| return crypto.SHA256 | |||||
| case bitSize <= 384: | |||||
| return crypto.SHA384 | |||||
| } | |||||
| return crypto.SHA512 | |||||
| } | |||||
| // parseECDSA parses an ECDSA key according to RFC 5656, section 3.1. | |||||
| func parseECDSA(in []byte) (out PublicKey, rest []byte, err error) { | |||||
| var w struct { | |||||
| Curve string | |||||
| KeyBytes []byte | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| if err := Unmarshal(in, &w); err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| key := new(ecdsa.PublicKey) | |||||
| switch w.Curve { | |||||
| case "nistp256": | |||||
| key.Curve = elliptic.P256() | |||||
| case "nistp384": | |||||
| key.Curve = elliptic.P384() | |||||
| case "nistp521": | |||||
| key.Curve = elliptic.P521() | |||||
| default: | |||||
| return nil, nil, errors.New("ssh: unsupported curve") | |||||
| } | |||||
| key.X, key.Y = elliptic.Unmarshal(key.Curve, w.KeyBytes) | |||||
| if key.X == nil || key.Y == nil { | |||||
| return nil, nil, errors.New("ssh: invalid curve point") | |||||
| } | |||||
| return (*ecdsaPublicKey)(key), w.Rest, nil | |||||
| } | |||||
| func (key *ecdsaPublicKey) Marshal() []byte { | |||||
| // See RFC 5656, section 3.1. | |||||
| keyBytes := elliptic.Marshal(key.Curve, key.X, key.Y) | |||||
| w := struct { | |||||
| Name string | |||||
| ID string | |||||
| Key []byte | |||||
| }{ | |||||
| key.Type(), | |||||
| key.nistID(), | |||||
| keyBytes, | |||||
| } | |||||
| return Marshal(&w) | |||||
| } | |||||
| func (key *ecdsaPublicKey) Verify(data []byte, sig *Signature) error { | |||||
| if sig.Format != key.Type() { | |||||
| return fmt.Errorf("ssh: signature type %s for key type %s", sig.Format, key.Type()) | |||||
| } | |||||
| h := ecHash(key.Curve).New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| // Per RFC 5656, section 3.1.2, | |||||
| // The ecdsa_signature_blob value has the following specific encoding: | |||||
| // mpint r | |||||
| // mpint s | |||||
| var ecSig struct { | |||||
| R *big.Int | |||||
| S *big.Int | |||||
| } | |||||
| if err := Unmarshal(sig.Blob, &ecSig); err != nil { | |||||
| return err | |||||
| } | |||||
| if ecdsa.Verify((*ecdsa.PublicKey)(key), digest, ecSig.R, ecSig.S) { | |||||
| return nil | |||||
| } | |||||
| return errors.New("ssh: signature did not verify") | |||||
| } | |||||
| type ecdsaPrivateKey struct { | |||||
| *ecdsa.PrivateKey | |||||
| } | |||||
| func (k *ecdsaPrivateKey) PublicKey() PublicKey { | |||||
| return (*ecdsaPublicKey)(&k.PrivateKey.PublicKey) | |||||
| } | |||||
| func (k *ecdsaPrivateKey) Sign(rand io.Reader, data []byte) (*Signature, error) { | |||||
| h := ecHash(k.PrivateKey.PublicKey.Curve).New() | |||||
| h.Write(data) | |||||
| digest := h.Sum(nil) | |||||
| r, s, err := ecdsa.Sign(rand, k.PrivateKey, digest) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| sig := make([]byte, intLength(r)+intLength(s)) | |||||
| rest := marshalInt(sig, r) | |||||
| marshalInt(rest, s) | |||||
| return &Signature{ | |||||
| Format: k.PublicKey().Type(), | |||||
| Blob: sig, | |||||
| }, nil | |||||
| } | |||||
| // NewSignerFromKey takes a pointer to rsa, dsa or ecdsa PrivateKey | |||||
| // returns a corresponding Signer instance. EC keys should use P256, | |||||
| // P384 or P521. | |||||
| func NewSignerFromKey(k interface{}) (Signer, error) { | |||||
| var sshKey Signer | |||||
| switch t := k.(type) { | |||||
| case *rsa.PrivateKey: | |||||
| sshKey = &rsaPrivateKey{t} | |||||
| case *dsa.PrivateKey: | |||||
| sshKey = &dsaPrivateKey{t} | |||||
| case *ecdsa.PrivateKey: | |||||
| if !supportedEllipticCurve(t.Curve) { | |||||
| return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") | |||||
| } | |||||
| sshKey = &ecdsaPrivateKey{t} | |||||
| default: | |||||
| return nil, fmt.Errorf("ssh: unsupported key type %T", k) | |||||
| } | |||||
| return sshKey, nil | |||||
| } | |||||
| // NewPublicKey takes a pointer to rsa, dsa or ecdsa PublicKey | |||||
| // and returns a corresponding ssh PublicKey instance. EC keys should use P256, P384 or P521. | |||||
| func NewPublicKey(k interface{}) (PublicKey, error) { | |||||
| var sshKey PublicKey | |||||
| switch t := k.(type) { | |||||
| case *rsa.PublicKey: | |||||
| sshKey = (*rsaPublicKey)(t) | |||||
| case *ecdsa.PublicKey: | |||||
| if !supportedEllipticCurve(t.Curve) { | |||||
| return nil, errors.New("ssh: only P256, P384 and P521 EC keys are supported.") | |||||
| } | |||||
| sshKey = (*ecdsaPublicKey)(t) | |||||
| case *dsa.PublicKey: | |||||
| sshKey = (*dsaPublicKey)(t) | |||||
| default: | |||||
| return nil, fmt.Errorf("ssh: unsupported key type %T", k) | |||||
| } | |||||
| return sshKey, nil | |||||
| } | |||||
| // ParsePrivateKey returns a Signer from a PEM encoded private key. It supports | |||||
| // the same keys as ParseRawPrivateKey. | |||||
| func ParsePrivateKey(pemBytes []byte) (Signer, error) { | |||||
| key, err := ParseRawPrivateKey(pemBytes) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return NewSignerFromKey(key) | |||||
| } | |||||
| // ParseRawPrivateKey returns a private key from a PEM encoded private key. It | |||||
| // supports RSA (PKCS#1), DSA (OpenSSL), and ECDSA private keys. | |||||
| func ParseRawPrivateKey(pemBytes []byte) (interface{}, error) { | |||||
| block, _ := pem.Decode(pemBytes) | |||||
| if block == nil { | |||||
| return nil, errors.New("ssh: no key found") | |||||
| } | |||||
| switch block.Type { | |||||
| case "RSA PRIVATE KEY": | |||||
| return x509.ParsePKCS1PrivateKey(block.Bytes) | |||||
| case "EC PRIVATE KEY": | |||||
| return x509.ParseECPrivateKey(block.Bytes) | |||||
| case "DSA PRIVATE KEY": | |||||
| return ParseDSAPrivateKey(block.Bytes) | |||||
| default: | |||||
| return nil, fmt.Errorf("ssh: unsupported key type %q", block.Type) | |||||
| } | |||||
| } | |||||
| // ParseDSAPrivateKey returns a DSA private key from its ASN.1 DER encoding, as | |||||
| // specified by the OpenSSL DSA man page. | |||||
| func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { | |||||
| var k struct { | |||||
| Version int | |||||
| P *big.Int | |||||
| Q *big.Int | |||||
| G *big.Int | |||||
| Priv *big.Int | |||||
| Pub *big.Int | |||||
| } | |||||
| rest, err := asn1.Unmarshal(der, &k) | |||||
| if err != nil { | |||||
| return nil, errors.New("ssh: failed to parse DSA key: " + err.Error()) | |||||
| } | |||||
| if len(rest) > 0 { | |||||
| return nil, errors.New("ssh: garbage after DSA key") | |||||
| } | |||||
| return &dsa.PrivateKey{ | |||||
| PublicKey: dsa.PublicKey{ | |||||
| Parameters: dsa.Parameters{ | |||||
| P: k.P, | |||||
| Q: k.Q, | |||||
| G: k.G, | |||||
| }, | |||||
| Y: k.Priv, | |||||
| }, | |||||
| X: k.Pub, | |||||
| }, nil | |||||
| } | |||||
| @@ -1,306 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/dsa" | |||||
| "crypto/ecdsa" | |||||
| "crypto/elliptic" | |||||
| "crypto/rand" | |||||
| "crypto/rsa" | |||||
| "encoding/base64" | |||||
| "fmt" | |||||
| "reflect" | |||||
| "strings" | |||||
| "testing" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh/testdata" | |||||
| ) | |||||
| func rawKey(pub PublicKey) interface{} { | |||||
| switch k := pub.(type) { | |||||
| case *rsaPublicKey: | |||||
| return (*rsa.PublicKey)(k) | |||||
| case *dsaPublicKey: | |||||
| return (*dsa.PublicKey)(k) | |||||
| case *ecdsaPublicKey: | |||||
| return (*ecdsa.PublicKey)(k) | |||||
| case *Certificate: | |||||
| return k | |||||
| } | |||||
| panic("unknown key type") | |||||
| } | |||||
| func TestKeyMarshalParse(t *testing.T) { | |||||
| for _, priv := range testSigners { | |||||
| pub := priv.PublicKey() | |||||
| roundtrip, err := ParsePublicKey(pub.Marshal()) | |||||
| if err != nil { | |||||
| t.Errorf("ParsePublicKey(%T): %v", pub, err) | |||||
| } | |||||
| k1 := rawKey(pub) | |||||
| k2 := rawKey(roundtrip) | |||||
| if !reflect.DeepEqual(k1, k2) { | |||||
| t.Errorf("got %#v in roundtrip, want %#v", k2, k1) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestUnsupportedCurves(t *testing.T) { | |||||
| raw, err := ecdsa.GenerateKey(elliptic.P224(), rand.Reader) | |||||
| if err != nil { | |||||
| t.Fatalf("GenerateKey: %v", err) | |||||
| } | |||||
| if _, err = NewSignerFromKey(raw); err == nil || !strings.Contains(err.Error(), "only P256") { | |||||
| t.Fatalf("NewPrivateKey should not succeed with P224, got: %v", err) | |||||
| } | |||||
| if _, err = NewPublicKey(&raw.PublicKey); err == nil || !strings.Contains(err.Error(), "only P256") { | |||||
| t.Fatalf("NewPublicKey should not succeed with P224, got: %v", err) | |||||
| } | |||||
| } | |||||
| func TestNewPublicKey(t *testing.T) { | |||||
| for _, k := range testSigners { | |||||
| raw := rawKey(k.PublicKey()) | |||||
| // Skip certificates, as NewPublicKey does not support them. | |||||
| if _, ok := raw.(*Certificate); ok { | |||||
| continue | |||||
| } | |||||
| pub, err := NewPublicKey(raw) | |||||
| if err != nil { | |||||
| t.Errorf("NewPublicKey(%#v): %v", raw, err) | |||||
| } | |||||
| if !reflect.DeepEqual(k.PublicKey(), pub) { | |||||
| t.Errorf("NewPublicKey(%#v) = %#v, want %#v", raw, pub, k.PublicKey()) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestKeySignVerify(t *testing.T) { | |||||
| for _, priv := range testSigners { | |||||
| pub := priv.PublicKey() | |||||
| data := []byte("sign me") | |||||
| sig, err := priv.Sign(rand.Reader, data) | |||||
| if err != nil { | |||||
| t.Fatalf("Sign(%T): %v", priv, err) | |||||
| } | |||||
| if err := pub.Verify(data, sig); err != nil { | |||||
| t.Errorf("publicKey.Verify(%T): %v", priv, err) | |||||
| } | |||||
| sig.Blob[5]++ | |||||
| if err := pub.Verify(data, sig); err == nil { | |||||
| t.Errorf("publicKey.Verify on broken sig did not fail") | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestParseRSAPrivateKey(t *testing.T) { | |||||
| key := testPrivateKeys["rsa"] | |||||
| rsa, ok := key.(*rsa.PrivateKey) | |||||
| if !ok { | |||||
| t.Fatalf("got %T, want *rsa.PrivateKey", rsa) | |||||
| } | |||||
| if err := rsa.Validate(); err != nil { | |||||
| t.Errorf("Validate: %v", err) | |||||
| } | |||||
| } | |||||
| func TestParseECPrivateKey(t *testing.T) { | |||||
| key := testPrivateKeys["ecdsa"] | |||||
| ecKey, ok := key.(*ecdsa.PrivateKey) | |||||
| if !ok { | |||||
| t.Fatalf("got %T, want *ecdsa.PrivateKey", ecKey) | |||||
| } | |||||
| if !validateECPublicKey(ecKey.Curve, ecKey.X, ecKey.Y) { | |||||
| t.Fatalf("public key does not validate.") | |||||
| } | |||||
| } | |||||
| func TestParseDSA(t *testing.T) { | |||||
| // We actually exercise the ParsePrivateKey codepath here, as opposed to | |||||
| // using the ParseRawPrivateKey+NewSignerFromKey path that testdata_test.go | |||||
| // uses. | |||||
| s, err := ParsePrivateKey(testdata.PEMBytes["dsa"]) | |||||
| if err != nil { | |||||
| t.Fatalf("ParsePrivateKey returned error: %s", err) | |||||
| } | |||||
| data := []byte("sign me") | |||||
| sig, err := s.Sign(rand.Reader, data) | |||||
| if err != nil { | |||||
| t.Fatalf("dsa.Sign: %v", err) | |||||
| } | |||||
| if err := s.PublicKey().Verify(data, sig); err != nil { | |||||
| t.Errorf("Verify failed: %v", err) | |||||
| } | |||||
| } | |||||
| // Tests for authorized_keys parsing. | |||||
| // getTestKey returns a public key, and its base64 encoding. | |||||
| func getTestKey() (PublicKey, string) { | |||||
| k := testPublicKeys["rsa"] | |||||
| b := &bytes.Buffer{} | |||||
| e := base64.NewEncoder(base64.StdEncoding, b) | |||||
| e.Write(k.Marshal()) | |||||
| e.Close() | |||||
| return k, b.String() | |||||
| } | |||||
| func TestMarshalParsePublicKey(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| line := fmt.Sprintf("%s %s user@host", pub.Type(), pubSerialized) | |||||
| authKeys := MarshalAuthorizedKey(pub) | |||||
| actualFields := strings.Fields(string(authKeys)) | |||||
| if len(actualFields) == 0 { | |||||
| t.Fatalf("failed authKeys: %v", authKeys) | |||||
| } | |||||
| // drop the comment | |||||
| expectedFields := strings.Fields(line)[0:2] | |||||
| if !reflect.DeepEqual(actualFields, expectedFields) { | |||||
| t.Errorf("got %v, expected %v", actualFields, expectedFields) | |||||
| } | |||||
| actPub, _, _, _, err := ParseAuthorizedKey([]byte(line)) | |||||
| if err != nil { | |||||
| t.Fatalf("cannot parse %v: %v", line, err) | |||||
| } | |||||
| if !reflect.DeepEqual(actPub, pub) { | |||||
| t.Errorf("got %v, expected %v", actPub, pub) | |||||
| } | |||||
| } | |||||
| type authResult struct { | |||||
| pubKey PublicKey | |||||
| options []string | |||||
| comments string | |||||
| rest string | |||||
| ok bool | |||||
| } | |||||
| func testAuthorizedKeys(t *testing.T, authKeys []byte, expected []authResult) { | |||||
| rest := authKeys | |||||
| var values []authResult | |||||
| for len(rest) > 0 { | |||||
| var r authResult | |||||
| var err error | |||||
| r.pubKey, r.comments, r.options, rest, err = ParseAuthorizedKey(rest) | |||||
| r.ok = (err == nil) | |||||
| t.Log(err) | |||||
| r.rest = string(rest) | |||||
| values = append(values, r) | |||||
| } | |||||
| if !reflect.DeepEqual(values, expected) { | |||||
| t.Errorf("got %#v, expected %#v", values, expected) | |||||
| } | |||||
| } | |||||
| func TestAuthorizedKeyBasic(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| line := "ssh-rsa " + pubSerialized + " user@host" | |||||
| testAuthorizedKeys(t, []byte(line), | |||||
| []authResult{ | |||||
| {pub, nil, "user@host", "", true}, | |||||
| }) | |||||
| } | |||||
| func TestAuth(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| authWithOptions := []string{ | |||||
| `# comments to ignore before any keys...`, | |||||
| ``, | |||||
| `env="HOME=/home/root",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`, | |||||
| `# comments to ignore, along with a blank line`, | |||||
| ``, | |||||
| `env="HOME=/home/root2" ssh-rsa ` + pubSerialized + ` user2@host2`, | |||||
| ``, | |||||
| `# more comments, plus a invalid entry`, | |||||
| `ssh-rsa data-that-will-not-parse user@host3`, | |||||
| } | |||||
| for _, eol := range []string{"\n", "\r\n"} { | |||||
| authOptions := strings.Join(authWithOptions, eol) | |||||
| rest2 := strings.Join(authWithOptions[3:], eol) | |||||
| rest3 := strings.Join(authWithOptions[6:], eol) | |||||
| testAuthorizedKeys(t, []byte(authOptions), []authResult{ | |||||
| {pub, []string{`env="HOME=/home/root"`, "no-port-forwarding"}, "user@host", rest2, true}, | |||||
| {pub, []string{`env="HOME=/home/root2"`}, "user2@host2", rest3, true}, | |||||
| {nil, nil, "", "", false}, | |||||
| }) | |||||
| } | |||||
| } | |||||
| func TestAuthWithQuotedSpaceInEnv(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| authWithQuotedSpaceInEnv := []byte(`env="HOME=/home/root dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) | |||||
| testAuthorizedKeys(t, []byte(authWithQuotedSpaceInEnv), []authResult{ | |||||
| {pub, []string{`env="HOME=/home/root dir"`, "no-port-forwarding"}, "user@host", "", true}, | |||||
| }) | |||||
| } | |||||
| func TestAuthWithQuotedCommaInEnv(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| authWithQuotedCommaInEnv := []byte(`env="HOME=/home/root,dir",no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host`) | |||||
| testAuthorizedKeys(t, []byte(authWithQuotedCommaInEnv), []authResult{ | |||||
| {pub, []string{`env="HOME=/home/root,dir"`, "no-port-forwarding"}, "user@host", "", true}, | |||||
| }) | |||||
| } | |||||
| func TestAuthWithQuotedQuoteInEnv(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| authWithQuotedQuoteInEnv := []byte(`env="HOME=/home/\"root dir",no-port-forwarding` + "\t" + `ssh-rsa` + "\t" + pubSerialized + ` user@host`) | |||||
| authWithDoubleQuotedQuote := []byte(`no-port-forwarding,env="HOME=/home/ \"root dir\"" ssh-rsa ` + pubSerialized + "\t" + `user@host`) | |||||
| testAuthorizedKeys(t, []byte(authWithQuotedQuoteInEnv), []authResult{ | |||||
| {pub, []string{`env="HOME=/home/\"root dir"`, "no-port-forwarding"}, "user@host", "", true}, | |||||
| }) | |||||
| testAuthorizedKeys(t, []byte(authWithDoubleQuotedQuote), []authResult{ | |||||
| {pub, []string{"no-port-forwarding", `env="HOME=/home/ \"root dir\""`}, "user@host", "", true}, | |||||
| }) | |||||
| } | |||||
| func TestAuthWithInvalidSpace(t *testing.T) { | |||||
| _, pubSerialized := getTestKey() | |||||
| authWithInvalidSpace := []byte(`env="HOME=/home/root dir", no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host | |||||
| #more to follow but still no valid keys`) | |||||
| testAuthorizedKeys(t, []byte(authWithInvalidSpace), []authResult{ | |||||
| {nil, nil, "", "", false}, | |||||
| }) | |||||
| } | |||||
| func TestAuthWithMissingQuote(t *testing.T) { | |||||
| pub, pubSerialized := getTestKey() | |||||
| authWithMissingQuote := []byte(`env="HOME=/home/root,no-port-forwarding ssh-rsa ` + pubSerialized + ` user@host | |||||
| env="HOME=/home/root",shared-control ssh-rsa ` + pubSerialized + ` user@host`) | |||||
| testAuthorizedKeys(t, []byte(authWithMissingQuote), []authResult{ | |||||
| {pub, []string{`env="HOME=/home/root"`, `shared-control`}, "user@host", "", true}, | |||||
| }) | |||||
| } | |||||
| func TestInvalidEntry(t *testing.T) { | |||||
| authInvalid := []byte(`ssh-rsa`) | |||||
| _, _, _, _, err := ParseAuthorizedKey(authInvalid) | |||||
| if err == nil { | |||||
| t.Errorf("got valid entry for %q", authInvalid) | |||||
| } | |||||
| } | |||||
| @@ -1,57 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| // Message authentication support | |||||
| import ( | |||||
| "crypto/hmac" | |||||
| "crypto/sha1" | |||||
| "crypto/sha256" | |||||
| "hash" | |||||
| ) | |||||
| type macMode struct { | |||||
| keySize int | |||||
| new func(key []byte) hash.Hash | |||||
| } | |||||
| // truncatingMAC wraps around a hash.Hash and truncates the output digest to | |||||
| // a given size. | |||||
| type truncatingMAC struct { | |||||
| length int | |||||
| hmac hash.Hash | |||||
| } | |||||
| func (t truncatingMAC) Write(data []byte) (int, error) { | |||||
| return t.hmac.Write(data) | |||||
| } | |||||
| func (t truncatingMAC) Sum(in []byte) []byte { | |||||
| out := t.hmac.Sum(in) | |||||
| return out[:len(in)+t.length] | |||||
| } | |||||
| func (t truncatingMAC) Reset() { | |||||
| t.hmac.Reset() | |||||
| } | |||||
| func (t truncatingMAC) Size() int { | |||||
| return t.length | |||||
| } | |||||
| func (t truncatingMAC) BlockSize() int { return t.hmac.BlockSize() } | |||||
| var macModes = map[string]*macMode{ | |||||
| "hmac-sha2-256": {32, func(key []byte) hash.Hash { | |||||
| return hmac.New(sha256.New, key) | |||||
| }}, | |||||
| "hmac-sha1": {20, func(key []byte) hash.Hash { | |||||
| return hmac.New(sha1.New, key) | |||||
| }}, | |||||
| "hmac-sha1-96": {20, func(key []byte) hash.Hash { | |||||
| return truncatingMAC{12, hmac.New(sha1.New, key)} | |||||
| }}, | |||||
| } | |||||
| @@ -1,110 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "io" | |||||
| "sync" | |||||
| "testing" | |||||
| ) | |||||
| // An in-memory packetConn. It is safe to call Close and writePacket | |||||
| // from different goroutines. | |||||
| type memTransport struct { | |||||
| eof bool | |||||
| pending [][]byte | |||||
| write *memTransport | |||||
| sync.Mutex | |||||
| *sync.Cond | |||||
| } | |||||
| func (t *memTransport) readPacket() ([]byte, error) { | |||||
| t.Lock() | |||||
| defer t.Unlock() | |||||
| for { | |||||
| if len(t.pending) > 0 { | |||||
| r := t.pending[0] | |||||
| t.pending = t.pending[1:] | |||||
| return r, nil | |||||
| } | |||||
| if t.eof { | |||||
| return nil, io.EOF | |||||
| } | |||||
| t.Cond.Wait() | |||||
| } | |||||
| } | |||||
| func (t *memTransport) closeSelf() error { | |||||
| t.Lock() | |||||
| defer t.Unlock() | |||||
| if t.eof { | |||||
| return io.EOF | |||||
| } | |||||
| t.eof = true | |||||
| t.Cond.Broadcast() | |||||
| return nil | |||||
| } | |||||
| func (t *memTransport) Close() error { | |||||
| err := t.write.closeSelf() | |||||
| t.closeSelf() | |||||
| return err | |||||
| } | |||||
| func (t *memTransport) writePacket(p []byte) error { | |||||
| t.write.Lock() | |||||
| defer t.write.Unlock() | |||||
| if t.write.eof { | |||||
| return io.EOF | |||||
| } | |||||
| c := make([]byte, len(p)) | |||||
| copy(c, p) | |||||
| t.write.pending = append(t.write.pending, c) | |||||
| t.write.Cond.Signal() | |||||
| return nil | |||||
| } | |||||
| func memPipe() (a, b packetConn) { | |||||
| t1 := memTransport{} | |||||
| t2 := memTransport{} | |||||
| t1.write = &t2 | |||||
| t2.write = &t1 | |||||
| t1.Cond = sync.NewCond(&t1.Mutex) | |||||
| t2.Cond = sync.NewCond(&t2.Mutex) | |||||
| return &t1, &t2 | |||||
| } | |||||
| func TestMemPipe(t *testing.T) { | |||||
| a, b := memPipe() | |||||
| if err := a.writePacket([]byte{42}); err != nil { | |||||
| t.Fatalf("writePacket: %v", err) | |||||
| } | |||||
| if err := a.Close(); err != nil { | |||||
| t.Fatal("Close: ", err) | |||||
| } | |||||
| p, err := b.readPacket() | |||||
| if err != nil { | |||||
| t.Fatal("readPacket: ", err) | |||||
| } | |||||
| if len(p) != 1 || p[0] != 42 { | |||||
| t.Fatalf("got %v, want {42}", p) | |||||
| } | |||||
| p, err = b.readPacket() | |||||
| if err != io.EOF { | |||||
| t.Fatalf("got %v, %v, want EOF", p, err) | |||||
| } | |||||
| } | |||||
| func TestDoubleClose(t *testing.T) { | |||||
| a, _ := memPipe() | |||||
| err := a.Close() | |||||
| if err != nil { | |||||
| t.Errorf("Close: %v", err) | |||||
| } | |||||
| err = a.Close() | |||||
| if err != io.EOF { | |||||
| t.Errorf("expect EOF on double close.") | |||||
| } | |||||
| } | |||||
| @@ -1,725 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "encoding/binary" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "math/big" | |||||
| "reflect" | |||||
| "strconv" | |||||
| ) | |||||
| // These are SSH message type numbers. They are scattered around several | |||||
| // documents but many were taken from [SSH-PARAMETERS]. | |||||
| const ( | |||||
| msgIgnore = 2 | |||||
| msgUnimplemented = 3 | |||||
| msgDebug = 4 | |||||
| msgNewKeys = 21 | |||||
| // Standard authentication messages | |||||
| msgUserAuthSuccess = 52 | |||||
| msgUserAuthBanner = 53 | |||||
| ) | |||||
| // SSH messages: | |||||
| // | |||||
| // These structures mirror the wire format of the corresponding SSH messages. | |||||
| // They are marshaled using reflection with the marshal and unmarshal functions | |||||
| // in this file. The only wrinkle is that a final member of type []byte with a | |||||
| // ssh tag of "rest" receives the remainder of a packet when unmarshaling. | |||||
| // See RFC 4253, section 11.1. | |||||
| const msgDisconnect = 1 | |||||
| // disconnectMsg is the message that signals a disconnect. It is also | |||||
| // the error type returned from mux.Wait() | |||||
| type disconnectMsg struct { | |||||
| Reason uint32 `sshtype:"1"` | |||||
| Message string | |||||
| Language string | |||||
| } | |||||
| func (d *disconnectMsg) Error() string { | |||||
| return fmt.Sprintf("ssh: disconnect reason %d: %s", d.Reason, d.Message) | |||||
| } | |||||
| // See RFC 4253, section 7.1. | |||||
| const msgKexInit = 20 | |||||
| type kexInitMsg struct { | |||||
| Cookie [16]byte `sshtype:"20"` | |||||
| KexAlgos []string | |||||
| ServerHostKeyAlgos []string | |||||
| CiphersClientServer []string | |||||
| CiphersServerClient []string | |||||
| MACsClientServer []string | |||||
| MACsServerClient []string | |||||
| CompressionClientServer []string | |||||
| CompressionServerClient []string | |||||
| LanguagesClientServer []string | |||||
| LanguagesServerClient []string | |||||
| FirstKexFollows bool | |||||
| Reserved uint32 | |||||
| } | |||||
| // See RFC 4253, section 8. | |||||
| // Diffie-Helman | |||||
| const msgKexDHInit = 30 | |||||
| type kexDHInitMsg struct { | |||||
| X *big.Int `sshtype:"30"` | |||||
| } | |||||
| const msgKexECDHInit = 30 | |||||
| type kexECDHInitMsg struct { | |||||
| ClientPubKey []byte `sshtype:"30"` | |||||
| } | |||||
| const msgKexECDHReply = 31 | |||||
| type kexECDHReplyMsg struct { | |||||
| HostKey []byte `sshtype:"31"` | |||||
| EphemeralPubKey []byte | |||||
| Signature []byte | |||||
| } | |||||
| const msgKexDHReply = 31 | |||||
| type kexDHReplyMsg struct { | |||||
| HostKey []byte `sshtype:"31"` | |||||
| Y *big.Int | |||||
| Signature []byte | |||||
| } | |||||
| // See RFC 4253, section 10. | |||||
| const msgServiceRequest = 5 | |||||
| type serviceRequestMsg struct { | |||||
| Service string `sshtype:"5"` | |||||
| } | |||||
| // See RFC 4253, section 10. | |||||
| const msgServiceAccept = 6 | |||||
| type serviceAcceptMsg struct { | |||||
| Service string `sshtype:"6"` | |||||
| } | |||||
| // See RFC 4252, section 5. | |||||
| const msgUserAuthRequest = 50 | |||||
| type userAuthRequestMsg struct { | |||||
| User string `sshtype:"50"` | |||||
| Service string | |||||
| Method string | |||||
| Payload []byte `ssh:"rest"` | |||||
| } | |||||
| // See RFC 4252, section 5.1 | |||||
| const msgUserAuthFailure = 51 | |||||
| type userAuthFailureMsg struct { | |||||
| Methods []string `sshtype:"51"` | |||||
| PartialSuccess bool | |||||
| } | |||||
| // See RFC 4256, section 3.2 | |||||
| const msgUserAuthInfoRequest = 60 | |||||
| const msgUserAuthInfoResponse = 61 | |||||
| type userAuthInfoRequestMsg struct { | |||||
| User string `sshtype:"60"` | |||||
| Instruction string | |||||
| DeprecatedLanguage string | |||||
| NumPrompts uint32 | |||||
| Prompts []byte `ssh:"rest"` | |||||
| } | |||||
| // See RFC 4254, section 5.1. | |||||
| const msgChannelOpen = 90 | |||||
| type channelOpenMsg struct { | |||||
| ChanType string `sshtype:"90"` | |||||
| PeersId uint32 | |||||
| PeersWindow uint32 | |||||
| MaxPacketSize uint32 | |||||
| TypeSpecificData []byte `ssh:"rest"` | |||||
| } | |||||
| const msgChannelExtendedData = 95 | |||||
| const msgChannelData = 94 | |||||
| // See RFC 4254, section 5.1. | |||||
| const msgChannelOpenConfirm = 91 | |||||
| type channelOpenConfirmMsg struct { | |||||
| PeersId uint32 `sshtype:"91"` | |||||
| MyId uint32 | |||||
| MyWindow uint32 | |||||
| MaxPacketSize uint32 | |||||
| TypeSpecificData []byte `ssh:"rest"` | |||||
| } | |||||
| // See RFC 4254, section 5.1. | |||||
| const msgChannelOpenFailure = 92 | |||||
| type channelOpenFailureMsg struct { | |||||
| PeersId uint32 `sshtype:"92"` | |||||
| Reason RejectionReason | |||||
| Message string | |||||
| Language string | |||||
| } | |||||
| const msgChannelRequest = 98 | |||||
| type channelRequestMsg struct { | |||||
| PeersId uint32 `sshtype:"98"` | |||||
| Request string | |||||
| WantReply bool | |||||
| RequestSpecificData []byte `ssh:"rest"` | |||||
| } | |||||
| // See RFC 4254, section 5.4. | |||||
| const msgChannelSuccess = 99 | |||||
| type channelRequestSuccessMsg struct { | |||||
| PeersId uint32 `sshtype:"99"` | |||||
| } | |||||
| // See RFC 4254, section 5.4. | |||||
| const msgChannelFailure = 100 | |||||
| type channelRequestFailureMsg struct { | |||||
| PeersId uint32 `sshtype:"100"` | |||||
| } | |||||
| // See RFC 4254, section 5.3 | |||||
| const msgChannelClose = 97 | |||||
| type channelCloseMsg struct { | |||||
| PeersId uint32 `sshtype:"97"` | |||||
| } | |||||
| // See RFC 4254, section 5.3 | |||||
| const msgChannelEOF = 96 | |||||
| type channelEOFMsg struct { | |||||
| PeersId uint32 `sshtype:"96"` | |||||
| } | |||||
| // See RFC 4254, section 4 | |||||
| const msgGlobalRequest = 80 | |||||
| type globalRequestMsg struct { | |||||
| Type string `sshtype:"80"` | |||||
| WantReply bool | |||||
| Data []byte `ssh:"rest"` | |||||
| } | |||||
| // See RFC 4254, section 4 | |||||
| const msgRequestSuccess = 81 | |||||
| type globalRequestSuccessMsg struct { | |||||
| Data []byte `ssh:"rest" sshtype:"81"` | |||||
| } | |||||
| // See RFC 4254, section 4 | |||||
| const msgRequestFailure = 82 | |||||
| type globalRequestFailureMsg struct { | |||||
| Data []byte `ssh:"rest" sshtype:"82"` | |||||
| } | |||||
| // See RFC 4254, section 5.2 | |||||
| const msgChannelWindowAdjust = 93 | |||||
| type windowAdjustMsg struct { | |||||
| PeersId uint32 `sshtype:"93"` | |||||
| AdditionalBytes uint32 | |||||
| } | |||||
| // See RFC 4252, section 7 | |||||
| const msgUserAuthPubKeyOk = 60 | |||||
| type userAuthPubKeyOkMsg struct { | |||||
| Algo string `sshtype:"60"` | |||||
| PubKey []byte | |||||
| } | |||||
| // typeTag returns the type byte for the given type. The type should | |||||
| // be struct. | |||||
| func typeTag(structType reflect.Type) byte { | |||||
| var tag byte | |||||
| var tagStr string | |||||
| tagStr = structType.Field(0).Tag.Get("sshtype") | |||||
| i, err := strconv.Atoi(tagStr) | |||||
| if err == nil { | |||||
| tag = byte(i) | |||||
| } | |||||
| return tag | |||||
| } | |||||
| func fieldError(t reflect.Type, field int, problem string) error { | |||||
| if problem != "" { | |||||
| problem = ": " + problem | |||||
| } | |||||
| return fmt.Errorf("ssh: unmarshal error for field %s of type %s%s", t.Field(field).Name, t.Name(), problem) | |||||
| } | |||||
| var errShortRead = errors.New("ssh: short read") | |||||
| // Unmarshal parses data in SSH wire format into a structure. The out | |||||
| // argument should be a pointer to struct. If the first member of the | |||||
| // struct has the "sshtype" tag set to a number in decimal, the packet | |||||
| // must start that number. In case of error, Unmarshal returns a | |||||
| // ParseError or UnexpectedMessageError. | |||||
| func Unmarshal(data []byte, out interface{}) error { | |||||
| v := reflect.ValueOf(out).Elem() | |||||
| structType := v.Type() | |||||
| expectedType := typeTag(structType) | |||||
| if len(data) == 0 { | |||||
| return parseError(expectedType) | |||||
| } | |||||
| if expectedType > 0 { | |||||
| if data[0] != expectedType { | |||||
| return unexpectedMessageError(expectedType, data[0]) | |||||
| } | |||||
| data = data[1:] | |||||
| } | |||||
| var ok bool | |||||
| for i := 0; i < v.NumField(); i++ { | |||||
| field := v.Field(i) | |||||
| t := field.Type() | |||||
| switch t.Kind() { | |||||
| case reflect.Bool: | |||||
| if len(data) < 1 { | |||||
| return errShortRead | |||||
| } | |||||
| field.SetBool(data[0] != 0) | |||||
| data = data[1:] | |||||
| case reflect.Array: | |||||
| if t.Elem().Kind() != reflect.Uint8 { | |||||
| return fieldError(structType, i, "array of unsupported type") | |||||
| } | |||||
| if len(data) < t.Len() { | |||||
| return errShortRead | |||||
| } | |||||
| for j, n := 0, t.Len(); j < n; j++ { | |||||
| field.Index(j).Set(reflect.ValueOf(data[j])) | |||||
| } | |||||
| data = data[t.Len():] | |||||
| case reflect.Uint64: | |||||
| var u64 uint64 | |||||
| if u64, data, ok = parseUint64(data); !ok { | |||||
| return errShortRead | |||||
| } | |||||
| field.SetUint(u64) | |||||
| case reflect.Uint32: | |||||
| var u32 uint32 | |||||
| if u32, data, ok = parseUint32(data); !ok { | |||||
| return errShortRead | |||||
| } | |||||
| field.SetUint(uint64(u32)) | |||||
| case reflect.Uint8: | |||||
| if len(data) < 1 { | |||||
| return errShortRead | |||||
| } | |||||
| field.SetUint(uint64(data[0])) | |||||
| data = data[1:] | |||||
| case reflect.String: | |||||
| var s []byte | |||||
| if s, data, ok = parseString(data); !ok { | |||||
| return fieldError(structType, i, "") | |||||
| } | |||||
| field.SetString(string(s)) | |||||
| case reflect.Slice: | |||||
| switch t.Elem().Kind() { | |||||
| case reflect.Uint8: | |||||
| if structType.Field(i).Tag.Get("ssh") == "rest" { | |||||
| field.Set(reflect.ValueOf(data)) | |||||
| data = nil | |||||
| } else { | |||||
| var s []byte | |||||
| if s, data, ok = parseString(data); !ok { | |||||
| return errShortRead | |||||
| } | |||||
| field.Set(reflect.ValueOf(s)) | |||||
| } | |||||
| case reflect.String: | |||||
| var nl []string | |||||
| if nl, data, ok = parseNameList(data); !ok { | |||||
| return errShortRead | |||||
| } | |||||
| field.Set(reflect.ValueOf(nl)) | |||||
| default: | |||||
| return fieldError(structType, i, "slice of unsupported type") | |||||
| } | |||||
| case reflect.Ptr: | |||||
| if t == bigIntType { | |||||
| var n *big.Int | |||||
| if n, data, ok = parseInt(data); !ok { | |||||
| return errShortRead | |||||
| } | |||||
| field.Set(reflect.ValueOf(n)) | |||||
| } else { | |||||
| return fieldError(structType, i, "pointer to unsupported type") | |||||
| } | |||||
| default: | |||||
| return fieldError(structType, i, "unsupported type") | |||||
| } | |||||
| } | |||||
| if len(data) != 0 { | |||||
| return parseError(expectedType) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // Marshal serializes the message in msg to SSH wire format. The msg | |||||
| // argument should be a struct or pointer to struct. If the first | |||||
| // member has the "sshtype" tag set to a number in decimal, that | |||||
| // number is prepended to the result. If the last of member has the | |||||
| // "ssh" tag set to "rest", its contents are appended to the output. | |||||
| func Marshal(msg interface{}) []byte { | |||||
| out := make([]byte, 0, 64) | |||||
| return marshalStruct(out, msg) | |||||
| } | |||||
| func marshalStruct(out []byte, msg interface{}) []byte { | |||||
| v := reflect.Indirect(reflect.ValueOf(msg)) | |||||
| msgType := typeTag(v.Type()) | |||||
| if msgType > 0 { | |||||
| out = append(out, msgType) | |||||
| } | |||||
| for i, n := 0, v.NumField(); i < n; i++ { | |||||
| field := v.Field(i) | |||||
| switch t := field.Type(); t.Kind() { | |||||
| case reflect.Bool: | |||||
| var v uint8 | |||||
| if field.Bool() { | |||||
| v = 1 | |||||
| } | |||||
| out = append(out, v) | |||||
| case reflect.Array: | |||||
| if t.Elem().Kind() != reflect.Uint8 { | |||||
| panic(fmt.Sprintf("array of non-uint8 in field %d: %T", i, field.Interface())) | |||||
| } | |||||
| for j, l := 0, t.Len(); j < l; j++ { | |||||
| out = append(out, uint8(field.Index(j).Uint())) | |||||
| } | |||||
| case reflect.Uint32: | |||||
| out = appendU32(out, uint32(field.Uint())) | |||||
| case reflect.Uint64: | |||||
| out = appendU64(out, uint64(field.Uint())) | |||||
| case reflect.Uint8: | |||||
| out = append(out, uint8(field.Uint())) | |||||
| case reflect.String: | |||||
| s := field.String() | |||||
| out = appendInt(out, len(s)) | |||||
| out = append(out, s...) | |||||
| case reflect.Slice: | |||||
| switch t.Elem().Kind() { | |||||
| case reflect.Uint8: | |||||
| if v.Type().Field(i).Tag.Get("ssh") != "rest" { | |||||
| out = appendInt(out, field.Len()) | |||||
| } | |||||
| out = append(out, field.Bytes()...) | |||||
| case reflect.String: | |||||
| offset := len(out) | |||||
| out = appendU32(out, 0) | |||||
| if n := field.Len(); n > 0 { | |||||
| for j := 0; j < n; j++ { | |||||
| f := field.Index(j) | |||||
| if j != 0 { | |||||
| out = append(out, ',') | |||||
| } | |||||
| out = append(out, f.String()...) | |||||
| } | |||||
| // overwrite length value | |||||
| binary.BigEndian.PutUint32(out[offset:], uint32(len(out)-offset-4)) | |||||
| } | |||||
| default: | |||||
| panic(fmt.Sprintf("slice of unknown type in field %d: %T", i, field.Interface())) | |||||
| } | |||||
| case reflect.Ptr: | |||||
| if t == bigIntType { | |||||
| var n *big.Int | |||||
| nValue := reflect.ValueOf(&n) | |||||
| nValue.Elem().Set(field) | |||||
| needed := intLength(n) | |||||
| oldLength := len(out) | |||||
| if cap(out)-len(out) < needed { | |||||
| newOut := make([]byte, len(out), 2*(len(out)+needed)) | |||||
| copy(newOut, out) | |||||
| out = newOut | |||||
| } | |||||
| out = out[:oldLength+needed] | |||||
| marshalInt(out[oldLength:], n) | |||||
| } else { | |||||
| panic(fmt.Sprintf("pointer to unknown type in field %d: %T", i, field.Interface())) | |||||
| } | |||||
| } | |||||
| } | |||||
| return out | |||||
| } | |||||
| var bigOne = big.NewInt(1) | |||||
| func parseString(in []byte) (out, rest []byte, ok bool) { | |||||
| if len(in) < 4 { | |||||
| return | |||||
| } | |||||
| length := binary.BigEndian.Uint32(in) | |||||
| in = in[4:] | |||||
| if uint32(len(in)) < length { | |||||
| return | |||||
| } | |||||
| out = in[:length] | |||||
| rest = in[length:] | |||||
| ok = true | |||||
| return | |||||
| } | |||||
| var ( | |||||
| comma = []byte{','} | |||||
| emptyNameList = []string{} | |||||
| ) | |||||
| func parseNameList(in []byte) (out []string, rest []byte, ok bool) { | |||||
| contents, rest, ok := parseString(in) | |||||
| if !ok { | |||||
| return | |||||
| } | |||||
| if len(contents) == 0 { | |||||
| out = emptyNameList | |||||
| return | |||||
| } | |||||
| parts := bytes.Split(contents, comma) | |||||
| out = make([]string, len(parts)) | |||||
| for i, part := range parts { | |||||
| out[i] = string(part) | |||||
| } | |||||
| return | |||||
| } | |||||
| func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) { | |||||
| contents, rest, ok := parseString(in) | |||||
| if !ok { | |||||
| return | |||||
| } | |||||
| out = new(big.Int) | |||||
| if len(contents) > 0 && contents[0]&0x80 == 0x80 { | |||||
| // This is a negative number | |||||
| notBytes := make([]byte, len(contents)) | |||||
| for i := range notBytes { | |||||
| notBytes[i] = ^contents[i] | |||||
| } | |||||
| out.SetBytes(notBytes) | |||||
| out.Add(out, bigOne) | |||||
| out.Neg(out) | |||||
| } else { | |||||
| // Positive number | |||||
| out.SetBytes(contents) | |||||
| } | |||||
| ok = true | |||||
| return | |||||
| } | |||||
| func parseUint32(in []byte) (uint32, []byte, bool) { | |||||
| if len(in) < 4 { | |||||
| return 0, nil, false | |||||
| } | |||||
| return binary.BigEndian.Uint32(in), in[4:], true | |||||
| } | |||||
| func parseUint64(in []byte) (uint64, []byte, bool) { | |||||
| if len(in) < 8 { | |||||
| return 0, nil, false | |||||
| } | |||||
| return binary.BigEndian.Uint64(in), in[8:], true | |||||
| } | |||||
| func intLength(n *big.Int) int { | |||||
| length := 4 /* length bytes */ | |||||
| if n.Sign() < 0 { | |||||
| nMinus1 := new(big.Int).Neg(n) | |||||
| nMinus1.Sub(nMinus1, bigOne) | |||||
| bitLen := nMinus1.BitLen() | |||||
| if bitLen%8 == 0 { | |||||
| // The number will need 0xff padding | |||||
| length++ | |||||
| } | |||||
| length += (bitLen + 7) / 8 | |||||
| } else if n.Sign() == 0 { | |||||
| // A zero is the zero length string | |||||
| } else { | |||||
| bitLen := n.BitLen() | |||||
| if bitLen%8 == 0 { | |||||
| // The number will need 0x00 padding | |||||
| length++ | |||||
| } | |||||
| length += (bitLen + 7) / 8 | |||||
| } | |||||
| return length | |||||
| } | |||||
| func marshalUint32(to []byte, n uint32) []byte { | |||||
| binary.BigEndian.PutUint32(to, n) | |||||
| return to[4:] | |||||
| } | |||||
| func marshalUint64(to []byte, n uint64) []byte { | |||||
| binary.BigEndian.PutUint64(to, n) | |||||
| return to[8:] | |||||
| } | |||||
| func marshalInt(to []byte, n *big.Int) []byte { | |||||
| lengthBytes := to | |||||
| to = to[4:] | |||||
| length := 0 | |||||
| if n.Sign() < 0 { | |||||
| // A negative number has to be converted to two's-complement | |||||
| // form. So we'll subtract 1 and invert. If the | |||||
| // most-significant-bit isn't set then we'll need to pad the | |||||
| // beginning with 0xff in order to keep the number negative. | |||||
| nMinus1 := new(big.Int).Neg(n) | |||||
| nMinus1.Sub(nMinus1, bigOne) | |||||
| bytes := nMinus1.Bytes() | |||||
| for i := range bytes { | |||||
| bytes[i] ^= 0xff | |||||
| } | |||||
| if len(bytes) == 0 || bytes[0]&0x80 == 0 { | |||||
| to[0] = 0xff | |||||
| to = to[1:] | |||||
| length++ | |||||
| } | |||||
| nBytes := copy(to, bytes) | |||||
| to = to[nBytes:] | |||||
| length += nBytes | |||||
| } else if n.Sign() == 0 { | |||||
| // A zero is the zero length string | |||||
| } else { | |||||
| bytes := n.Bytes() | |||||
| if len(bytes) > 0 && bytes[0]&0x80 != 0 { | |||||
| // We'll have to pad this with a 0x00 in order to | |||||
| // stop it looking like a negative number. | |||||
| to[0] = 0 | |||||
| to = to[1:] | |||||
| length++ | |||||
| } | |||||
| nBytes := copy(to, bytes) | |||||
| to = to[nBytes:] | |||||
| length += nBytes | |||||
| } | |||||
| lengthBytes[0] = byte(length >> 24) | |||||
| lengthBytes[1] = byte(length >> 16) | |||||
| lengthBytes[2] = byte(length >> 8) | |||||
| lengthBytes[3] = byte(length) | |||||
| return to | |||||
| } | |||||
| func writeInt(w io.Writer, n *big.Int) { | |||||
| length := intLength(n) | |||||
| buf := make([]byte, length) | |||||
| marshalInt(buf, n) | |||||
| w.Write(buf) | |||||
| } | |||||
| func writeString(w io.Writer, s []byte) { | |||||
| var lengthBytes [4]byte | |||||
| lengthBytes[0] = byte(len(s) >> 24) | |||||
| lengthBytes[1] = byte(len(s) >> 16) | |||||
| lengthBytes[2] = byte(len(s) >> 8) | |||||
| lengthBytes[3] = byte(len(s)) | |||||
| w.Write(lengthBytes[:]) | |||||
| w.Write(s) | |||||
| } | |||||
| func stringLength(n int) int { | |||||
| return 4 + n | |||||
| } | |||||
| func marshalString(to []byte, s []byte) []byte { | |||||
| to[0] = byte(len(s) >> 24) | |||||
| to[1] = byte(len(s) >> 16) | |||||
| to[2] = byte(len(s) >> 8) | |||||
| to[3] = byte(len(s)) | |||||
| to = to[4:] | |||||
| copy(to, s) | |||||
| return to[len(s):] | |||||
| } | |||||
| var bigIntType = reflect.TypeOf((*big.Int)(nil)) | |||||
| // Decode a packet into its corresponding message. | |||||
| func decode(packet []byte) (interface{}, error) { | |||||
| var msg interface{} | |||||
| switch packet[0] { | |||||
| case msgDisconnect: | |||||
| msg = new(disconnectMsg) | |||||
| case msgServiceRequest: | |||||
| msg = new(serviceRequestMsg) | |||||
| case msgServiceAccept: | |||||
| msg = new(serviceAcceptMsg) | |||||
| case msgKexInit: | |||||
| msg = new(kexInitMsg) | |||||
| case msgKexDHInit: | |||||
| msg = new(kexDHInitMsg) | |||||
| case msgKexDHReply: | |||||
| msg = new(kexDHReplyMsg) | |||||
| case msgUserAuthRequest: | |||||
| msg = new(userAuthRequestMsg) | |||||
| case msgUserAuthFailure: | |||||
| msg = new(userAuthFailureMsg) | |||||
| case msgUserAuthPubKeyOk: | |||||
| msg = new(userAuthPubKeyOkMsg) | |||||
| case msgGlobalRequest: | |||||
| msg = new(globalRequestMsg) | |||||
| case msgRequestSuccess: | |||||
| msg = new(globalRequestSuccessMsg) | |||||
| case msgRequestFailure: | |||||
| msg = new(globalRequestFailureMsg) | |||||
| case msgChannelOpen: | |||||
| msg = new(channelOpenMsg) | |||||
| case msgChannelOpenConfirm: | |||||
| msg = new(channelOpenConfirmMsg) | |||||
| case msgChannelOpenFailure: | |||||
| msg = new(channelOpenFailureMsg) | |||||
| case msgChannelWindowAdjust: | |||||
| msg = new(windowAdjustMsg) | |||||
| case msgChannelEOF: | |||||
| msg = new(channelEOFMsg) | |||||
| case msgChannelClose: | |||||
| msg = new(channelCloseMsg) | |||||
| case msgChannelRequest: | |||||
| msg = new(channelRequestMsg) | |||||
| case msgChannelSuccess: | |||||
| msg = new(channelRequestSuccessMsg) | |||||
| case msgChannelFailure: | |||||
| msg = new(channelRequestFailureMsg) | |||||
| default: | |||||
| return nil, unexpectedMessageError(0, packet[0]) | |||||
| } | |||||
| if err := Unmarshal(packet, msg); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return msg, nil | |||||
| } | |||||
| @@ -1,254 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "math/big" | |||||
| "math/rand" | |||||
| "reflect" | |||||
| "testing" | |||||
| "testing/quick" | |||||
| ) | |||||
| var intLengthTests = []struct { | |||||
| val, length int | |||||
| }{ | |||||
| {0, 4 + 0}, | |||||
| {1, 4 + 1}, | |||||
| {127, 4 + 1}, | |||||
| {128, 4 + 2}, | |||||
| {-1, 4 + 1}, | |||||
| } | |||||
| func TestIntLength(t *testing.T) { | |||||
| for _, test := range intLengthTests { | |||||
| v := new(big.Int).SetInt64(int64(test.val)) | |||||
| length := intLength(v) | |||||
| if length != test.length { | |||||
| t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length) | |||||
| } | |||||
| } | |||||
| } | |||||
| type msgAllTypes struct { | |||||
| Bool bool `sshtype:"21"` | |||||
| Array [16]byte | |||||
| Uint64 uint64 | |||||
| Uint32 uint32 | |||||
| Uint8 uint8 | |||||
| String string | |||||
| Strings []string | |||||
| Bytes []byte | |||||
| Int *big.Int | |||||
| Rest []byte `ssh:"rest"` | |||||
| } | |||||
| func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value { | |||||
| m := &msgAllTypes{} | |||||
| m.Bool = rand.Intn(2) == 1 | |||||
| randomBytes(m.Array[:], rand) | |||||
| m.Uint64 = uint64(rand.Int63n(1<<63 - 1)) | |||||
| m.Uint32 = uint32(rand.Intn((1 << 31) - 1)) | |||||
| m.Uint8 = uint8(rand.Intn(1 << 8)) | |||||
| m.String = string(m.Array[:]) | |||||
| m.Strings = randomNameList(rand) | |||||
| m.Bytes = m.Array[:] | |||||
| m.Int = randomInt(rand) | |||||
| m.Rest = m.Array[:] | |||||
| return reflect.ValueOf(m) | |||||
| } | |||||
| func TestMarshalUnmarshal(t *testing.T) { | |||||
| rand := rand.New(rand.NewSource(0)) | |||||
| iface := &msgAllTypes{} | |||||
| ty := reflect.ValueOf(iface).Type() | |||||
| n := 100 | |||||
| if testing.Short() { | |||||
| n = 5 | |||||
| } | |||||
| for j := 0; j < n; j++ { | |||||
| v, ok := quick.Value(ty, rand) | |||||
| if !ok { | |||||
| t.Errorf("failed to create value") | |||||
| break | |||||
| } | |||||
| m1 := v.Elem().Interface() | |||||
| m2 := iface | |||||
| marshaled := Marshal(m1) | |||||
| if err := Unmarshal(marshaled, m2); err != nil { | |||||
| t.Errorf("Unmarshal %#v: %s", m1, err) | |||||
| break | |||||
| } | |||||
| if !reflect.DeepEqual(v.Interface(), m2) { | |||||
| t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled) | |||||
| break | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestUnmarshalEmptyPacket(t *testing.T) { | |||||
| var b []byte | |||||
| var m channelRequestSuccessMsg | |||||
| if err := Unmarshal(b, &m); err == nil { | |||||
| t.Fatalf("unmarshal of empty slice succeeded") | |||||
| } | |||||
| } | |||||
| func TestUnmarshalUnexpectedPacket(t *testing.T) { | |||||
| type S struct { | |||||
| I uint32 `sshtype:"43"` | |||||
| S string | |||||
| B bool | |||||
| } | |||||
| s := S{11, "hello", true} | |||||
| packet := Marshal(s) | |||||
| packet[0] = 42 | |||||
| roundtrip := S{} | |||||
| err := Unmarshal(packet, &roundtrip) | |||||
| if err == nil { | |||||
| t.Fatal("expected error, not nil") | |||||
| } | |||||
| } | |||||
| func TestMarshalPtr(t *testing.T) { | |||||
| s := struct { | |||||
| S string | |||||
| }{"hello"} | |||||
| m1 := Marshal(s) | |||||
| m2 := Marshal(&s) | |||||
| if !bytes.Equal(m1, m2) { | |||||
| t.Errorf("got %q, want %q for marshaled pointer", m2, m1) | |||||
| } | |||||
| } | |||||
| func TestBareMarshalUnmarshal(t *testing.T) { | |||||
| type S struct { | |||||
| I uint32 | |||||
| S string | |||||
| B bool | |||||
| } | |||||
| s := S{42, "hello", true} | |||||
| packet := Marshal(s) | |||||
| roundtrip := S{} | |||||
| Unmarshal(packet, &roundtrip) | |||||
| if !reflect.DeepEqual(s, roundtrip) { | |||||
| t.Errorf("got %#v, want %#v", roundtrip, s) | |||||
| } | |||||
| } | |||||
| func TestBareMarshal(t *testing.T) { | |||||
| type S2 struct { | |||||
| I uint32 | |||||
| } | |||||
| s := S2{42} | |||||
| packet := Marshal(s) | |||||
| i, rest, ok := parseUint32(packet) | |||||
| if len(rest) > 0 || !ok { | |||||
| t.Errorf("parseInt(%q): parse error", packet) | |||||
| } | |||||
| if i != s.I { | |||||
| t.Errorf("got %d, want %d", i, s.I) | |||||
| } | |||||
| } | |||||
| func TestUnmarshalShortKexInitPacket(t *testing.T) { | |||||
| // This used to panic. | |||||
| // Issue 11348 | |||||
| packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff} | |||||
| kim := &kexInitMsg{} | |||||
| if err := Unmarshal(packet, kim); err == nil { | |||||
| t.Error("truncated packet unmarshaled without error") | |||||
| } | |||||
| } | |||||
| func randomBytes(out []byte, rand *rand.Rand) { | |||||
| for i := 0; i < len(out); i++ { | |||||
| out[i] = byte(rand.Int31()) | |||||
| } | |||||
| } | |||||
| func randomNameList(rand *rand.Rand) []string { | |||||
| ret := make([]string, rand.Int31()&15) | |||||
| for i := range ret { | |||||
| s := make([]byte, 1+(rand.Int31()&15)) | |||||
| for j := range s { | |||||
| s[j] = 'a' + uint8(rand.Int31()&15) | |||||
| } | |||||
| ret[i] = string(s) | |||||
| } | |||||
| return ret | |||||
| } | |||||
| func randomInt(rand *rand.Rand) *big.Int { | |||||
| return new(big.Int).SetInt64(int64(int32(rand.Uint32()))) | |||||
| } | |||||
| func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { | |||||
| ki := &kexInitMsg{} | |||||
| randomBytes(ki.Cookie[:], rand) | |||||
| ki.KexAlgos = randomNameList(rand) | |||||
| ki.ServerHostKeyAlgos = randomNameList(rand) | |||||
| ki.CiphersClientServer = randomNameList(rand) | |||||
| ki.CiphersServerClient = randomNameList(rand) | |||||
| ki.MACsClientServer = randomNameList(rand) | |||||
| ki.MACsServerClient = randomNameList(rand) | |||||
| ki.CompressionClientServer = randomNameList(rand) | |||||
| ki.CompressionServerClient = randomNameList(rand) | |||||
| ki.LanguagesClientServer = randomNameList(rand) | |||||
| ki.LanguagesServerClient = randomNameList(rand) | |||||
| if rand.Int31()&1 == 1 { | |||||
| ki.FirstKexFollows = true | |||||
| } | |||||
| return reflect.ValueOf(ki) | |||||
| } | |||||
| func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value { | |||||
| dhi := &kexDHInitMsg{} | |||||
| dhi.X = randomInt(rand) | |||||
| return reflect.ValueOf(dhi) | |||||
| } | |||||
| var ( | |||||
| _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() | |||||
| _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface() | |||||
| _kexInit = Marshal(_kexInitMsg) | |||||
| _kexDHInit = Marshal(_kexDHInitMsg) | |||||
| ) | |||||
| func BenchmarkMarshalKexInitMsg(b *testing.B) { | |||||
| for i := 0; i < b.N; i++ { | |||||
| Marshal(_kexInitMsg) | |||||
| } | |||||
| } | |||||
| func BenchmarkUnmarshalKexInitMsg(b *testing.B) { | |||||
| m := new(kexInitMsg) | |||||
| for i := 0; i < b.N; i++ { | |||||
| Unmarshal(_kexInit, m) | |||||
| } | |||||
| } | |||||
| func BenchmarkMarshalKexDHInitMsg(b *testing.B) { | |||||
| for i := 0; i < b.N; i++ { | |||||
| Marshal(_kexDHInitMsg) | |||||
| } | |||||
| } | |||||
| func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) { | |||||
| m := new(kexDHInitMsg) | |||||
| for i := 0; i < b.N; i++ { | |||||
| Unmarshal(_kexDHInit, m) | |||||
| } | |||||
| } | |||||
| @@ -1,356 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "encoding/binary" | |||||
| "fmt" | |||||
| "io" | |||||
| "log" | |||||
| "sync" | |||||
| "sync/atomic" | |||||
| ) | |||||
| // debugMux, if set, causes messages in the connection protocol to be | |||||
| // logged. | |||||
| const debugMux = false | |||||
| // chanList is a thread safe channel list. | |||||
| type chanList struct { | |||||
| // protects concurrent access to chans | |||||
| sync.Mutex | |||||
| // chans are indexed by the local id of the channel, which the | |||||
| // other side should send in the PeersId field. | |||||
| chans []*channel | |||||
| // This is a debugging aid: it offsets all IDs by this | |||||
| // amount. This helps distinguish otherwise identical | |||||
| // server/client muxes | |||||
| offset uint32 | |||||
| } | |||||
| // Assigns a channel ID to the given channel. | |||||
| func (c *chanList) add(ch *channel) uint32 { | |||||
| c.Lock() | |||||
| defer c.Unlock() | |||||
| for i := range c.chans { | |||||
| if c.chans[i] == nil { | |||||
| c.chans[i] = ch | |||||
| return uint32(i) + c.offset | |||||
| } | |||||
| } | |||||
| c.chans = append(c.chans, ch) | |||||
| return uint32(len(c.chans)-1) + c.offset | |||||
| } | |||||
| // getChan returns the channel for the given ID. | |||||
| func (c *chanList) getChan(id uint32) *channel { | |||||
| id -= c.offset | |||||
| c.Lock() | |||||
| defer c.Unlock() | |||||
| if id < uint32(len(c.chans)) { | |||||
| return c.chans[id] | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func (c *chanList) remove(id uint32) { | |||||
| id -= c.offset | |||||
| c.Lock() | |||||
| if id < uint32(len(c.chans)) { | |||||
| c.chans[id] = nil | |||||
| } | |||||
| c.Unlock() | |||||
| } | |||||
| // dropAll forgets all channels it knows, returning them in a slice. | |||||
| func (c *chanList) dropAll() []*channel { | |||||
| c.Lock() | |||||
| defer c.Unlock() | |||||
| var r []*channel | |||||
| for _, ch := range c.chans { | |||||
| if ch == nil { | |||||
| continue | |||||
| } | |||||
| r = append(r, ch) | |||||
| } | |||||
| c.chans = nil | |||||
| return r | |||||
| } | |||||
| // mux represents the state for the SSH connection protocol, which | |||||
| // multiplexes many channels onto a single packet transport. | |||||
| type mux struct { | |||||
| conn packetConn | |||||
| chanList chanList | |||||
| incomingChannels chan NewChannel | |||||
| globalSentMu sync.Mutex | |||||
| globalResponses chan interface{} | |||||
| incomingRequests chan *Request | |||||
| errCond *sync.Cond | |||||
| err error | |||||
| } | |||||
| // When debugging, each new chanList instantiation has a different | |||||
| // offset. | |||||
| var globalOff uint32 | |||||
| func (m *mux) Wait() error { | |||||
| m.errCond.L.Lock() | |||||
| defer m.errCond.L.Unlock() | |||||
| for m.err == nil { | |||||
| m.errCond.Wait() | |||||
| } | |||||
| return m.err | |||||
| } | |||||
| // newMux returns a mux that runs over the given connection. | |||||
| func newMux(p packetConn) *mux { | |||||
| m := &mux{ | |||||
| conn: p, | |||||
| incomingChannels: make(chan NewChannel, 16), | |||||
| globalResponses: make(chan interface{}, 1), | |||||
| incomingRequests: make(chan *Request, 16), | |||||
| errCond: newCond(), | |||||
| } | |||||
| if debugMux { | |||||
| m.chanList.offset = atomic.AddUint32(&globalOff, 1) | |||||
| } | |||||
| go m.loop() | |||||
| return m | |||||
| } | |||||
| func (m *mux) sendMessage(msg interface{}) error { | |||||
| p := Marshal(msg) | |||||
| return m.conn.writePacket(p) | |||||
| } | |||||
| func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { | |||||
| if wantReply { | |||||
| m.globalSentMu.Lock() | |||||
| defer m.globalSentMu.Unlock() | |||||
| } | |||||
| if err := m.sendMessage(globalRequestMsg{ | |||||
| Type: name, | |||||
| WantReply: wantReply, | |||||
| Data: payload, | |||||
| }); err != nil { | |||||
| return false, nil, err | |||||
| } | |||||
| if !wantReply { | |||||
| return false, nil, nil | |||||
| } | |||||
| msg, ok := <-m.globalResponses | |||||
| if !ok { | |||||
| return false, nil, io.EOF | |||||
| } | |||||
| switch msg := msg.(type) { | |||||
| case *globalRequestFailureMsg: | |||||
| return false, msg.Data, nil | |||||
| case *globalRequestSuccessMsg: | |||||
| return true, msg.Data, nil | |||||
| default: | |||||
| return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) | |||||
| } | |||||
| } | |||||
| // ackRequest must be called after processing a global request that | |||||
| // has WantReply set. | |||||
| func (m *mux) ackRequest(ok bool, data []byte) error { | |||||
| if ok { | |||||
| return m.sendMessage(globalRequestSuccessMsg{Data: data}) | |||||
| } | |||||
| return m.sendMessage(globalRequestFailureMsg{Data: data}) | |||||
| } | |||||
| // TODO(hanwen): Disconnect is a transport layer message. We should | |||||
| // probably send and receive Disconnect somewhere in the transport | |||||
| // code. | |||||
| // Disconnect sends a disconnect message. | |||||
| func (m *mux) Disconnect(reason uint32, message string) error { | |||||
| return m.sendMessage(disconnectMsg{ | |||||
| Reason: reason, | |||||
| Message: message, | |||||
| }) | |||||
| } | |||||
| func (m *mux) Close() error { | |||||
| return m.conn.Close() | |||||
| } | |||||
| // loop runs the connection machine. It will process packets until an | |||||
| // error is encountered. To synchronize on loop exit, use mux.Wait. | |||||
| func (m *mux) loop() { | |||||
| var err error | |||||
| for err == nil { | |||||
| err = m.onePacket() | |||||
| } | |||||
| for _, ch := range m.chanList.dropAll() { | |||||
| ch.close() | |||||
| } | |||||
| close(m.incomingChannels) | |||||
| close(m.incomingRequests) | |||||
| close(m.globalResponses) | |||||
| m.conn.Close() | |||||
| m.errCond.L.Lock() | |||||
| m.err = err | |||||
| m.errCond.Broadcast() | |||||
| m.errCond.L.Unlock() | |||||
| if debugMux { | |||||
| log.Println("loop exit", err) | |||||
| } | |||||
| } | |||||
| // onePacket reads and processes one packet. | |||||
| func (m *mux) onePacket() error { | |||||
| packet, err := m.conn.readPacket() | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if debugMux { | |||||
| if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { | |||||
| log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) | |||||
| } else { | |||||
| p, _ := decode(packet) | |||||
| log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) | |||||
| } | |||||
| } | |||||
| switch packet[0] { | |||||
| case msgNewKeys: | |||||
| // Ignore notification of key change. | |||||
| return nil | |||||
| case msgDisconnect: | |||||
| return m.handleDisconnect(packet) | |||||
| case msgChannelOpen: | |||||
| return m.handleChannelOpen(packet) | |||||
| case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: | |||||
| return m.handleGlobalPacket(packet) | |||||
| } | |||||
| // assume a channel packet. | |||||
| if len(packet) < 5 { | |||||
| return parseError(packet[0]) | |||||
| } | |||||
| id := binary.BigEndian.Uint32(packet[1:]) | |||||
| ch := m.chanList.getChan(id) | |||||
| if ch == nil { | |||||
| return fmt.Errorf("ssh: invalid channel %d", id) | |||||
| } | |||||
| return ch.handlePacket(packet) | |||||
| } | |||||
| func (m *mux) handleDisconnect(packet []byte) error { | |||||
| var d disconnectMsg | |||||
| if err := Unmarshal(packet, &d); err != nil { | |||||
| return err | |||||
| } | |||||
| if debugMux { | |||||
| log.Printf("caught disconnect: %v", d) | |||||
| } | |||||
| return &d | |||||
| } | |||||
| func (m *mux) handleGlobalPacket(packet []byte) error { | |||||
| msg, err := decode(packet) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| switch msg := msg.(type) { | |||||
| case *globalRequestMsg: | |||||
| m.incomingRequests <- &Request{ | |||||
| Type: msg.Type, | |||||
| WantReply: msg.WantReply, | |||||
| Payload: msg.Data, | |||||
| mux: m, | |||||
| } | |||||
| case *globalRequestSuccessMsg, *globalRequestFailureMsg: | |||||
| m.globalResponses <- msg | |||||
| default: | |||||
| panic(fmt.Sprintf("not a global message %#v", msg)) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // handleChannelOpen schedules a channel to be Accept()ed. | |||||
| func (m *mux) handleChannelOpen(packet []byte) error { | |||||
| var msg channelOpenMsg | |||||
| if err := Unmarshal(packet, &msg); err != nil { | |||||
| return err | |||||
| } | |||||
| if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | |||||
| failMsg := channelOpenFailureMsg{ | |||||
| PeersId: msg.PeersId, | |||||
| Reason: ConnectionFailed, | |||||
| Message: "invalid request", | |||||
| Language: "en_US.UTF-8", | |||||
| } | |||||
| return m.sendMessage(failMsg) | |||||
| } | |||||
| c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) | |||||
| c.remoteId = msg.PeersId | |||||
| c.maxRemotePayload = msg.MaxPacketSize | |||||
| c.remoteWin.add(msg.PeersWindow) | |||||
| m.incomingChannels <- c | |||||
| return nil | |||||
| } | |||||
| func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { | |||||
| ch, err := m.openChannel(chanType, extra) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| return ch, ch.incomingRequests, nil | |||||
| } | |||||
| func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { | |||||
| ch := m.newChannel(chanType, channelOutbound, extra) | |||||
| ch.maxIncomingPayload = channelMaxPacket | |||||
| open := channelOpenMsg{ | |||||
| ChanType: chanType, | |||||
| PeersWindow: ch.myWindow, | |||||
| MaxPacketSize: ch.maxIncomingPayload, | |||||
| TypeSpecificData: extra, | |||||
| PeersId: ch.localId, | |||||
| } | |||||
| if err := m.sendMessage(open); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| switch msg := (<-ch.msg).(type) { | |||||
| case *channelOpenConfirmMsg: | |||||
| return ch, nil | |||||
| case *channelOpenFailureMsg: | |||||
| return nil, &OpenChannelError{msg.Reason, msg.Message} | |||||
| default: | |||||
| return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) | |||||
| } | |||||
| } | |||||
| @@ -1,525 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "io" | |||||
| "io/ioutil" | |||||
| "sync" | |||||
| "testing" | |||||
| ) | |||||
| func muxPair() (*mux, *mux) { | |||||
| a, b := memPipe() | |||||
| s := newMux(a) | |||||
| c := newMux(b) | |||||
| return s, c | |||||
| } | |||||
| // Returns both ends of a channel, and the mux for the the 2nd | |||||
| // channel. | |||||
| func channelPair(t *testing.T) (*channel, *channel, *mux) { | |||||
| c, s := muxPair() | |||||
| res := make(chan *channel, 1) | |||||
| go func() { | |||||
| newCh, ok := <-s.incomingChannels | |||||
| if !ok { | |||||
| t.Fatalf("No incoming channel") | |||||
| } | |||||
| if newCh.ChannelType() != "chan" { | |||||
| t.Fatalf("got type %q want chan", newCh.ChannelType()) | |||||
| } | |||||
| ch, _, err := newCh.Accept() | |||||
| if err != nil { | |||||
| t.Fatalf("Accept %v", err) | |||||
| } | |||||
| res <- ch.(*channel) | |||||
| }() | |||||
| ch, err := c.openChannel("chan", nil) | |||||
| if err != nil { | |||||
| t.Fatalf("OpenChannel: %v", err) | |||||
| } | |||||
| return <-res, ch, c | |||||
| } | |||||
| // Test that stderr and stdout can be addressed from different | |||||
| // goroutines. This is intended for use with the race detector. | |||||
| func TestMuxChannelExtendedThreadSafety(t *testing.T) { | |||||
| writer, reader, mux := channelPair(t) | |||||
| defer writer.Close() | |||||
| defer reader.Close() | |||||
| defer mux.Close() | |||||
| var wr, rd sync.WaitGroup | |||||
| magic := "hello world" | |||||
| wr.Add(2) | |||||
| go func() { | |||||
| io.WriteString(writer, magic) | |||||
| wr.Done() | |||||
| }() | |||||
| go func() { | |||||
| io.WriteString(writer.Stderr(), magic) | |||||
| wr.Done() | |||||
| }() | |||||
| rd.Add(2) | |||||
| go func() { | |||||
| c, err := ioutil.ReadAll(reader) | |||||
| if string(c) != magic { | |||||
| t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err) | |||||
| } | |||||
| rd.Done() | |||||
| }() | |||||
| go func() { | |||||
| c, err := ioutil.ReadAll(reader.Stderr()) | |||||
| if string(c) != magic { | |||||
| t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err) | |||||
| } | |||||
| rd.Done() | |||||
| }() | |||||
| wr.Wait() | |||||
| writer.CloseWrite() | |||||
| rd.Wait() | |||||
| } | |||||
| func TestMuxReadWrite(t *testing.T) { | |||||
| s, c, mux := channelPair(t) | |||||
| defer s.Close() | |||||
| defer c.Close() | |||||
| defer mux.Close() | |||||
| magic := "hello world" | |||||
| magicExt := "hello stderr" | |||||
| go func() { | |||||
| _, err := s.Write([]byte(magic)) | |||||
| if err != nil { | |||||
| t.Fatalf("Write: %v", err) | |||||
| } | |||||
| _, err = s.Extended(1).Write([]byte(magicExt)) | |||||
| if err != nil { | |||||
| t.Fatalf("Write: %v", err) | |||||
| } | |||||
| err = s.Close() | |||||
| if err != nil { | |||||
| t.Fatalf("Close: %v", err) | |||||
| } | |||||
| }() | |||||
| var buf [1024]byte | |||||
| n, err := c.Read(buf[:]) | |||||
| if err != nil { | |||||
| t.Fatalf("server Read: %v", err) | |||||
| } | |||||
| got := string(buf[:n]) | |||||
| if got != magic { | |||||
| t.Fatalf("server: got %q want %q", got, magic) | |||||
| } | |||||
| n, err = c.Extended(1).Read(buf[:]) | |||||
| if err != nil { | |||||
| t.Fatalf("server Read: %v", err) | |||||
| } | |||||
| got = string(buf[:n]) | |||||
| if got != magicExt { | |||||
| t.Fatalf("server: got %q want %q", got, magic) | |||||
| } | |||||
| } | |||||
| func TestMuxChannelOverflow(t *testing.T) { | |||||
| reader, writer, mux := channelPair(t) | |||||
| defer reader.Close() | |||||
| defer writer.Close() | |||||
| defer mux.Close() | |||||
| wDone := make(chan int, 1) | |||||
| go func() { | |||||
| if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { | |||||
| t.Errorf("could not fill window: %v", err) | |||||
| } | |||||
| writer.Write(make([]byte, 1)) | |||||
| wDone <- 1 | |||||
| }() | |||||
| writer.remoteWin.waitWriterBlocked() | |||||
| // Send 1 byte. | |||||
| packet := make([]byte, 1+4+4+1) | |||||
| packet[0] = msgChannelData | |||||
| marshalUint32(packet[1:], writer.remoteId) | |||||
| marshalUint32(packet[5:], uint32(1)) | |||||
| packet[9] = 42 | |||||
| if err := writer.mux.conn.writePacket(packet); err != nil { | |||||
| t.Errorf("could not send packet") | |||||
| } | |||||
| if _, err := reader.SendRequest("hello", true, nil); err == nil { | |||||
| t.Errorf("SendRequest succeeded.") | |||||
| } | |||||
| <-wDone | |||||
| } | |||||
| func TestMuxChannelCloseWriteUnblock(t *testing.T) { | |||||
| reader, writer, mux := channelPair(t) | |||||
| defer reader.Close() | |||||
| defer writer.Close() | |||||
| defer mux.Close() | |||||
| wDone := make(chan int, 1) | |||||
| go func() { | |||||
| if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { | |||||
| t.Errorf("could not fill window: %v", err) | |||||
| } | |||||
| if _, err := writer.Write(make([]byte, 1)); err != io.EOF { | |||||
| t.Errorf("got %v, want EOF for unblock write", err) | |||||
| } | |||||
| wDone <- 1 | |||||
| }() | |||||
| writer.remoteWin.waitWriterBlocked() | |||||
| reader.Close() | |||||
| <-wDone | |||||
| } | |||||
| func TestMuxConnectionCloseWriteUnblock(t *testing.T) { | |||||
| reader, writer, mux := channelPair(t) | |||||
| defer reader.Close() | |||||
| defer writer.Close() | |||||
| defer mux.Close() | |||||
| wDone := make(chan int, 1) | |||||
| go func() { | |||||
| if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil { | |||||
| t.Errorf("could not fill window: %v", err) | |||||
| } | |||||
| if _, err := writer.Write(make([]byte, 1)); err != io.EOF { | |||||
| t.Errorf("got %v, want EOF for unblock write", err) | |||||
| } | |||||
| wDone <- 1 | |||||
| }() | |||||
| writer.remoteWin.waitWriterBlocked() | |||||
| mux.Close() | |||||
| <-wDone | |||||
| } | |||||
| func TestMuxReject(t *testing.T) { | |||||
| client, server := muxPair() | |||||
| defer server.Close() | |||||
| defer client.Close() | |||||
| go func() { | |||||
| ch, ok := <-server.incomingChannels | |||||
| if !ok { | |||||
| t.Fatalf("Accept") | |||||
| } | |||||
| if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { | |||||
| t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) | |||||
| } | |||||
| ch.Reject(RejectionReason(42), "message") | |||||
| }() | |||||
| ch, err := client.openChannel("ch", []byte("extra")) | |||||
| if ch != nil { | |||||
| t.Fatal("openChannel not rejected") | |||||
| } | |||||
| ocf, ok := err.(*OpenChannelError) | |||||
| if !ok { | |||||
| t.Errorf("got %#v want *OpenChannelError", err) | |||||
| } else if ocf.Reason != 42 || ocf.Message != "message" { | |||||
| t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message") | |||||
| } | |||||
| want := "ssh: rejected: unknown reason 42 (message)" | |||||
| if err.Error() != want { | |||||
| t.Errorf("got %q, want %q", err.Error(), want) | |||||
| } | |||||
| } | |||||
| func TestMuxChannelRequest(t *testing.T) { | |||||
| client, server, mux := channelPair(t) | |||||
| defer server.Close() | |||||
| defer client.Close() | |||||
| defer mux.Close() | |||||
| var received int | |||||
| var wg sync.WaitGroup | |||||
| wg.Add(1) | |||||
| go func() { | |||||
| for r := range server.incomingRequests { | |||||
| received++ | |||||
| r.Reply(r.Type == "yes", nil) | |||||
| } | |||||
| wg.Done() | |||||
| }() | |||||
| _, err := client.SendRequest("yes", false, nil) | |||||
| if err != nil { | |||||
| t.Fatalf("SendRequest: %v", err) | |||||
| } | |||||
| ok, err := client.SendRequest("yes", true, nil) | |||||
| if err != nil { | |||||
| t.Fatalf("SendRequest: %v", err) | |||||
| } | |||||
| if !ok { | |||||
| t.Errorf("SendRequest(yes): %v", ok) | |||||
| } | |||||
| ok, err = client.SendRequest("no", true, nil) | |||||
| if err != nil { | |||||
| t.Fatalf("SendRequest: %v", err) | |||||
| } | |||||
| if ok { | |||||
| t.Errorf("SendRequest(no): %v", ok) | |||||
| } | |||||
| client.Close() | |||||
| wg.Wait() | |||||
| if received != 3 { | |||||
| t.Errorf("got %d requests, want %d", received, 3) | |||||
| } | |||||
| } | |||||
| func TestMuxGlobalRequest(t *testing.T) { | |||||
| clientMux, serverMux := muxPair() | |||||
| defer serverMux.Close() | |||||
| defer clientMux.Close() | |||||
| var seen bool | |||||
| go func() { | |||||
| for r := range serverMux.incomingRequests { | |||||
| seen = seen || r.Type == "peek" | |||||
| if r.WantReply { | |||||
| err := r.Reply(r.Type == "yes", | |||||
| append([]byte(r.Type), r.Payload...)) | |||||
| if err != nil { | |||||
| t.Errorf("AckRequest: %v", err) | |||||
| } | |||||
| } | |||||
| } | |||||
| }() | |||||
| _, _, err := clientMux.SendRequest("peek", false, nil) | |||||
| if err != nil { | |||||
| t.Errorf("SendRequest: %v", err) | |||||
| } | |||||
| ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) | |||||
| if !ok || string(data) != "yesa" || err != nil { | |||||
| t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", | |||||
| ok, data, err) | |||||
| } | |||||
| if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { | |||||
| t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", | |||||
| ok, data, err) | |||||
| } | |||||
| if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { | |||||
| t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", | |||||
| ok, data, err) | |||||
| } | |||||
| clientMux.Disconnect(0, "") | |||||
| if !seen { | |||||
| t.Errorf("never saw 'peek' request") | |||||
| } | |||||
| } | |||||
| func TestMuxGlobalRequestUnblock(t *testing.T) { | |||||
| clientMux, serverMux := muxPair() | |||||
| defer serverMux.Close() | |||||
| defer clientMux.Close() | |||||
| result := make(chan error, 1) | |||||
| go func() { | |||||
| _, _, err := clientMux.SendRequest("hello", true, nil) | |||||
| result <- err | |||||
| }() | |||||
| <-serverMux.incomingRequests | |||||
| serverMux.conn.Close() | |||||
| err := <-result | |||||
| if err != io.EOF { | |||||
| t.Errorf("want EOF, got %v", io.EOF) | |||||
| } | |||||
| } | |||||
| func TestMuxChannelRequestUnblock(t *testing.T) { | |||||
| a, b, connB := channelPair(t) | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| defer connB.Close() | |||||
| result := make(chan error, 1) | |||||
| go func() { | |||||
| _, err := a.SendRequest("hello", true, nil) | |||||
| result <- err | |||||
| }() | |||||
| <-b.incomingRequests | |||||
| connB.conn.Close() | |||||
| err := <-result | |||||
| if err != io.EOF { | |||||
| t.Errorf("want EOF, got %v", err) | |||||
| } | |||||
| } | |||||
| func TestMuxDisconnect(t *testing.T) { | |||||
| a, b := muxPair() | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| go func() { | |||||
| for r := range b.incomingRequests { | |||||
| r.Reply(true, nil) | |||||
| } | |||||
| }() | |||||
| a.Disconnect(42, "whatever") | |||||
| ok, _, err := a.SendRequest("hello", true, nil) | |||||
| if ok || err == nil { | |||||
| t.Errorf("got reply after disconnecting") | |||||
| } | |||||
| err = b.Wait() | |||||
| if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 { | |||||
| t.Errorf("got %#v, want disconnectMsg{Reason:42}", err) | |||||
| } | |||||
| } | |||||
| func TestMuxCloseChannel(t *testing.T) { | |||||
| r, w, mux := channelPair(t) | |||||
| defer mux.Close() | |||||
| defer r.Close() | |||||
| defer w.Close() | |||||
| result := make(chan error, 1) | |||||
| go func() { | |||||
| var b [1024]byte | |||||
| _, err := r.Read(b[:]) | |||||
| result <- err | |||||
| }() | |||||
| if err := w.Close(); err != nil { | |||||
| t.Errorf("w.Close: %v", err) | |||||
| } | |||||
| if _, err := w.Write([]byte("hello")); err != io.EOF { | |||||
| t.Errorf("got err %v, want io.EOF after Close", err) | |||||
| } | |||||
| if err := <-result; err != io.EOF { | |||||
| t.Errorf("got %v (%T), want io.EOF", err, err) | |||||
| } | |||||
| } | |||||
| func TestMuxCloseWriteChannel(t *testing.T) { | |||||
| r, w, mux := channelPair(t) | |||||
| defer mux.Close() | |||||
| result := make(chan error, 1) | |||||
| go func() { | |||||
| var b [1024]byte | |||||
| _, err := r.Read(b[:]) | |||||
| result <- err | |||||
| }() | |||||
| if err := w.CloseWrite(); err != nil { | |||||
| t.Errorf("w.CloseWrite: %v", err) | |||||
| } | |||||
| if _, err := w.Write([]byte("hello")); err != io.EOF { | |||||
| t.Errorf("got err %v, want io.EOF after CloseWrite", err) | |||||
| } | |||||
| if err := <-result; err != io.EOF { | |||||
| t.Errorf("got %v (%T), want io.EOF", err, err) | |||||
| } | |||||
| } | |||||
| func TestMuxInvalidRecord(t *testing.T) { | |||||
| a, b := muxPair() | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| packet := make([]byte, 1+4+4+1) | |||||
| packet[0] = msgChannelData | |||||
| marshalUint32(packet[1:], 29348723 /* invalid channel id */) | |||||
| marshalUint32(packet[5:], 1) | |||||
| packet[9] = 42 | |||||
| a.conn.writePacket(packet) | |||||
| go a.SendRequest("hello", false, nil) | |||||
| // 'a' wrote an invalid packet, so 'b' has exited. | |||||
| req, ok := <-b.incomingRequests | |||||
| if ok { | |||||
| t.Errorf("got request %#v after receiving invalid packet", req) | |||||
| } | |||||
| } | |||||
| func TestZeroWindowAdjust(t *testing.T) { | |||||
| a, b, mux := channelPair(t) | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| defer mux.Close() | |||||
| go func() { | |||||
| io.WriteString(a, "hello") | |||||
| // bogus adjust. | |||||
| a.sendMessage(windowAdjustMsg{}) | |||||
| io.WriteString(a, "world") | |||||
| a.Close() | |||||
| }() | |||||
| want := "helloworld" | |||||
| c, _ := ioutil.ReadAll(b) | |||||
| if string(c) != want { | |||||
| t.Errorf("got %q want %q", c, want) | |||||
| } | |||||
| } | |||||
| func TestMuxMaxPacketSize(t *testing.T) { | |||||
| a, b, mux := channelPair(t) | |||||
| defer a.Close() | |||||
| defer b.Close() | |||||
| defer mux.Close() | |||||
| large := make([]byte, a.maxRemotePayload+1) | |||||
| packet := make([]byte, 1+4+4+1+len(large)) | |||||
| packet[0] = msgChannelData | |||||
| marshalUint32(packet[1:], a.remoteId) | |||||
| marshalUint32(packet[5:], uint32(len(large))) | |||||
| packet[9] = 42 | |||||
| if err := a.mux.conn.writePacket(packet); err != nil { | |||||
| t.Errorf("could not send packet") | |||||
| } | |||||
| go a.SendRequest("hello", false, nil) | |||||
| _, ok := <-b.incomingRequests | |||||
| if ok { | |||||
| t.Errorf("connection still alive after receiving large packet.") | |||||
| } | |||||
| } | |||||
| // Don't ship code with debug=true. | |||||
| func TestDebug(t *testing.T) { | |||||
| if debugMux { | |||||
| t.Error("mux debug switched on") | |||||
| } | |||||
| if debugHandshake { | |||||
| t.Error("handshake debug switched on") | |||||
| } | |||||
| } | |||||
| @@ -1,493 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "net" | |||||
| ) | |||||
| // The Permissions type holds fine-grained permissions that are | |||||
| // specific to a user or a specific authentication method for a | |||||
| // user. Permissions, except for "source-address", must be enforced in | |||||
| // the server application layer, after successful authentication. The | |||||
| // Permissions are passed on in ServerConn so a server implementation | |||||
| // can honor them. | |||||
| type Permissions struct { | |||||
| // Critical options restrict default permissions. Common | |||||
| // restrictions are "source-address" and "force-command". If | |||||
| // the server cannot enforce the restriction, or does not | |||||
| // recognize it, the user should not authenticate. | |||||
| CriticalOptions map[string]string | |||||
| // Extensions are extra functionality that the server may | |||||
| // offer on authenticated connections. Common extensions are | |||||
| // "permit-agent-forwarding", "permit-X11-forwarding". Lack of | |||||
| // support for an extension does not preclude authenticating a | |||||
| // user. | |||||
| Extensions map[string]string | |||||
| } | |||||
| // ServerConfig holds server specific configuration data. | |||||
| type ServerConfig struct { | |||||
| // Config contains configuration shared between client and server. | |||||
| Config | |||||
| hostKeys []Signer | |||||
| // NoClientAuth is true if clients are allowed to connect without | |||||
| // authenticating. | |||||
| NoClientAuth bool | |||||
| // PasswordCallback, if non-nil, is called when a user | |||||
| // attempts to authenticate using a password. | |||||
| PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) | |||||
| // PublicKeyCallback, if non-nil, is called when a client attempts public | |||||
| // key authentication. It must return true if the given public key is | |||||
| // valid for the given user. For example, see CertChecker.Authenticate. | |||||
| PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error) | |||||
| // KeyboardInteractiveCallback, if non-nil, is called when | |||||
| // keyboard-interactive authentication is selected (RFC | |||||
| // 4256). The client object's Challenge function should be | |||||
| // used to query the user. The callback may offer multiple | |||||
| // Challenge rounds. To avoid information leaks, the client | |||||
| // should be presented a challenge even if the user is | |||||
| // unknown. | |||||
| KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) | |||||
| // AuthLogCallback, if non-nil, is called to log all authentication | |||||
| // attempts. | |||||
| AuthLogCallback func(conn ConnMetadata, method string, err error) | |||||
| // ServerVersion is the version identification string to | |||||
| // announce in the public handshake. | |||||
| // If empty, a reasonable default is used. | |||||
| ServerVersion string | |||||
| } | |||||
| // AddHostKey adds a private key as a host key. If an existing host | |||||
| // key exists with the same algorithm, it is overwritten. Each server | |||||
| // config must have at least one host key. | |||||
| func (s *ServerConfig) AddHostKey(key Signer) { | |||||
| for i, k := range s.hostKeys { | |||||
| if k.PublicKey().Type() == key.PublicKey().Type() { | |||||
| s.hostKeys[i] = key | |||||
| return | |||||
| } | |||||
| } | |||||
| s.hostKeys = append(s.hostKeys, key) | |||||
| } | |||||
| // cachedPubKey contains the results of querying whether a public key is | |||||
| // acceptable for a user. | |||||
| type cachedPubKey struct { | |||||
| user string | |||||
| pubKeyData []byte | |||||
| result error | |||||
| perms *Permissions | |||||
| } | |||||
| const maxCachedPubKeys = 16 | |||||
| // pubKeyCache caches tests for public keys. Since SSH clients | |||||
| // will query whether a public key is acceptable before attempting to | |||||
| // authenticate with it, we end up with duplicate queries for public | |||||
| // key validity. The cache only applies to a single ServerConn. | |||||
| type pubKeyCache struct { | |||||
| keys []cachedPubKey | |||||
| } | |||||
| // get returns the result for a given user/algo/key tuple. | |||||
| func (c *pubKeyCache) get(user string, pubKeyData []byte) (cachedPubKey, bool) { | |||||
| for _, k := range c.keys { | |||||
| if k.user == user && bytes.Equal(k.pubKeyData, pubKeyData) { | |||||
| return k, true | |||||
| } | |||||
| } | |||||
| return cachedPubKey{}, false | |||||
| } | |||||
| // add adds the given tuple to the cache. | |||||
| func (c *pubKeyCache) add(candidate cachedPubKey) { | |||||
| if len(c.keys) < maxCachedPubKeys { | |||||
| c.keys = append(c.keys, candidate) | |||||
| } | |||||
| } | |||||
| // ServerConn is an authenticated SSH connection, as seen from the | |||||
| // server | |||||
| type ServerConn struct { | |||||
| Conn | |||||
| // If the succeeding authentication callback returned a | |||||
| // non-nil Permissions pointer, it is stored here. | |||||
| Permissions *Permissions | |||||
| } | |||||
| // NewServerConn starts a new SSH server with c as the underlying | |||||
| // transport. It starts with a handshake and, if the handshake is | |||||
| // unsuccessful, it closes the connection and returns an error. The | |||||
| // Request and NewChannel channels must be serviced, or the connection | |||||
| // will hang. | |||||
| func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { | |||||
| fullConf := *config | |||||
| fullConf.SetDefaults() | |||||
| s := &connection{ | |||||
| sshConn: sshConn{conn: c}, | |||||
| } | |||||
| perms, err := s.serverHandshake(&fullConf) | |||||
| if err != nil { | |||||
| c.Close() | |||||
| return nil, nil, nil, err | |||||
| } | |||||
| return &ServerConn{s, perms}, s.mux.incomingChannels, s.mux.incomingRequests, nil | |||||
| } | |||||
| // signAndMarshal signs the data with the appropriate algorithm, | |||||
| // and serializes the result in SSH wire format. | |||||
| func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { | |||||
| sig, err := k.Sign(rand, data) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return Marshal(sig), nil | |||||
| } | |||||
| // handshake performs key exchange and user authentication. | |||||
| func (s *connection) serverHandshake(config *ServerConfig) (*Permissions, error) { | |||||
| if len(config.hostKeys) == 0 { | |||||
| return nil, errors.New("ssh: server has no host keys") | |||||
| } | |||||
| if !config.NoClientAuth && config.PasswordCallback == nil && config.PublicKeyCallback == nil && config.KeyboardInteractiveCallback == nil { | |||||
| return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") | |||||
| } | |||||
| if config.ServerVersion != "" { | |||||
| s.serverVersion = []byte(config.ServerVersion) | |||||
| } else { | |||||
| s.serverVersion = []byte(packageVersion) | |||||
| } | |||||
| var err error | |||||
| s.clientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| tr := newTransport(s.sshConn.conn, config.Rand, false /* not client */) | |||||
| s.transport = newServerTransport(tr, s.clientVersion, s.serverVersion, config) | |||||
| if err := s.transport.requestKeyChange(); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if packet, err := s.transport.readPacket(); err != nil { | |||||
| return nil, err | |||||
| } else if packet[0] != msgNewKeys { | |||||
| return nil, unexpectedMessageError(msgNewKeys, packet[0]) | |||||
| } | |||||
| // We just did the key change, so the session ID is established. | |||||
| s.sessionID = s.transport.getSessionID() | |||||
| var packet []byte | |||||
| if packet, err = s.transport.readPacket(); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| var serviceRequest serviceRequestMsg | |||||
| if err = Unmarshal(packet, &serviceRequest); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if serviceRequest.Service != serviceUserAuth { | |||||
| return nil, errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating") | |||||
| } | |||||
| serviceAccept := serviceAcceptMsg{ | |||||
| Service: serviceUserAuth, | |||||
| } | |||||
| if err := s.transport.writePacket(Marshal(&serviceAccept)); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| perms, err := s.serverAuthenticate(config) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| s.mux = newMux(s.transport) | |||||
| return perms, err | |||||
| } | |||||
| func isAcceptableAlgo(algo string) bool { | |||||
| switch algo { | |||||
| case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoECDSA521, | |||||
| CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECDSA384v01, CertAlgoECDSA521v01: | |||||
| return true | |||||
| } | |||||
| return false | |||||
| } | |||||
| func checkSourceAddress(addr net.Addr, sourceAddr string) error { | |||||
| if addr == nil { | |||||
| return errors.New("ssh: no address known for client, but source-address match required") | |||||
| } | |||||
| tcpAddr, ok := addr.(*net.TCPAddr) | |||||
| if !ok { | |||||
| return fmt.Errorf("ssh: remote address %v is not an TCP address when checking source-address match", addr) | |||||
| } | |||||
| if allowedIP := net.ParseIP(sourceAddr); allowedIP != nil { | |||||
| if bytes.Equal(allowedIP, tcpAddr.IP) { | |||||
| return nil | |||||
| } | |||||
| } else { | |||||
| _, ipNet, err := net.ParseCIDR(sourceAddr) | |||||
| if err != nil { | |||||
| return fmt.Errorf("ssh: error parsing source-address restriction %q: %v", sourceAddr, err) | |||||
| } | |||||
| if ipNet.Contains(tcpAddr.IP) { | |||||
| return nil | |||||
| } | |||||
| } | |||||
| return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) | |||||
| } | |||||
| func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) { | |||||
| var err error | |||||
| var cache pubKeyCache | |||||
| var perms *Permissions | |||||
| userAuthLoop: | |||||
| for { | |||||
| var userAuthReq userAuthRequestMsg | |||||
| if packet, err := s.transport.readPacket(); err != nil { | |||||
| return nil, err | |||||
| } else if err = Unmarshal(packet, &userAuthReq); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if userAuthReq.Service != serviceSSH { | |||||
| return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) | |||||
| } | |||||
| s.user = userAuthReq.User | |||||
| perms = nil | |||||
| authErr := errors.New("no auth passed yet") | |||||
| switch userAuthReq.Method { | |||||
| case "none": | |||||
| if config.NoClientAuth { | |||||
| s.user = "" | |||||
| authErr = nil | |||||
| } | |||||
| case "password": | |||||
| if config.PasswordCallback == nil { | |||||
| authErr = errors.New("ssh: password auth not configured") | |||||
| break | |||||
| } | |||||
| payload := userAuthReq.Payload | |||||
| if len(payload) < 1 || payload[0] != 0 { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| payload = payload[1:] | |||||
| password, payload, ok := parseString(payload) | |||||
| if !ok || len(payload) > 0 { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| perms, authErr = config.PasswordCallback(s, password) | |||||
| case "keyboard-interactive": | |||||
| if config.KeyboardInteractiveCallback == nil { | |||||
| authErr = errors.New("ssh: keyboard-interactive auth not configubred") | |||||
| break | |||||
| } | |||||
| prompter := &sshClientKeyboardInteractive{s} | |||||
| perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) | |||||
| case "publickey": | |||||
| if config.PublicKeyCallback == nil { | |||||
| authErr = errors.New("ssh: publickey auth not configured") | |||||
| break | |||||
| } | |||||
| payload := userAuthReq.Payload | |||||
| if len(payload) < 1 { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| isQuery := payload[0] == 0 | |||||
| payload = payload[1:] | |||||
| algoBytes, payload, ok := parseString(payload) | |||||
| if !ok { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| algo := string(algoBytes) | |||||
| if !isAcceptableAlgo(algo) { | |||||
| authErr = fmt.Errorf("ssh: algorithm %q not accepted", algo) | |||||
| break | |||||
| } | |||||
| pubKeyData, payload, ok := parseString(payload) | |||||
| if !ok { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| pubKey, err := ParsePublicKey(pubKeyData) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| candidate, ok := cache.get(s.user, pubKeyData) | |||||
| if !ok { | |||||
| candidate.user = s.user | |||||
| candidate.pubKeyData = pubKeyData | |||||
| candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) | |||||
| if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { | |||||
| candidate.result = checkSourceAddress( | |||||
| s.RemoteAddr(), | |||||
| candidate.perms.CriticalOptions[sourceAddressCriticalOption]) | |||||
| } | |||||
| cache.add(candidate) | |||||
| } | |||||
| if isQuery { | |||||
| // The client can query if the given public key | |||||
| // would be okay. | |||||
| if len(payload) > 0 { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| if candidate.result == nil { | |||||
| okMsg := userAuthPubKeyOkMsg{ | |||||
| Algo: algo, | |||||
| PubKey: pubKeyData, | |||||
| } | |||||
| if err = s.transport.writePacket(Marshal(&okMsg)); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| continue userAuthLoop | |||||
| } | |||||
| authErr = candidate.result | |||||
| } else { | |||||
| sig, payload, ok := parseSignature(payload) | |||||
| if !ok || len(payload) > 0 { | |||||
| return nil, parseError(msgUserAuthRequest) | |||||
| } | |||||
| // Ensure the public key algo and signature algo | |||||
| // are supported. Compare the private key | |||||
| // algorithm name that corresponds to algo with | |||||
| // sig.Format. This is usually the same, but | |||||
| // for certs, the names differ. | |||||
| if !isAcceptableAlgo(sig.Format) { | |||||
| break | |||||
| } | |||||
| signedData := buildDataSignedForAuth(s.transport.getSessionID(), userAuthReq, algoBytes, pubKeyData) | |||||
| if err := pubKey.Verify(signedData, sig); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| authErr = candidate.result | |||||
| perms = candidate.perms | |||||
| } | |||||
| default: | |||||
| authErr = fmt.Errorf("ssh: unknown method %q", userAuthReq.Method) | |||||
| } | |||||
| if config.AuthLogCallback != nil { | |||||
| config.AuthLogCallback(s, userAuthReq.Method, authErr) | |||||
| } | |||||
| if authErr == nil { | |||||
| break userAuthLoop | |||||
| } | |||||
| var failureMsg userAuthFailureMsg | |||||
| if config.PasswordCallback != nil { | |||||
| failureMsg.Methods = append(failureMsg.Methods, "password") | |||||
| } | |||||
| if config.PublicKeyCallback != nil { | |||||
| failureMsg.Methods = append(failureMsg.Methods, "publickey") | |||||
| } | |||||
| if config.KeyboardInteractiveCallback != nil { | |||||
| failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") | |||||
| } | |||||
| if len(failureMsg.Methods) == 0 { | |||||
| return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") | |||||
| } | |||||
| if err = s.transport.writePacket(Marshal(&failureMsg)); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| } | |||||
| if err = s.transport.writePacket([]byte{msgUserAuthSuccess}); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return perms, nil | |||||
| } | |||||
| // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by | |||||
| // asking the client on the other side of a ServerConn. | |||||
| type sshClientKeyboardInteractive struct { | |||||
| *connection | |||||
| } | |||||
| func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, questions []string, echos []bool) (answers []string, err error) { | |||||
| if len(questions) != len(echos) { | |||||
| return nil, errors.New("ssh: echos and questions must have equal length") | |||||
| } | |||||
| var prompts []byte | |||||
| for i := range questions { | |||||
| prompts = appendString(prompts, questions[i]) | |||||
| prompts = appendBool(prompts, echos[i]) | |||||
| } | |||||
| if err := c.transport.writePacket(Marshal(&userAuthInfoRequestMsg{ | |||||
| Instruction: instruction, | |||||
| NumPrompts: uint32(len(questions)), | |||||
| Prompts: prompts, | |||||
| })); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| packet, err := c.transport.readPacket() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if packet[0] != msgUserAuthInfoResponse { | |||||
| return nil, unexpectedMessageError(msgUserAuthInfoResponse, packet[0]) | |||||
| } | |||||
| packet = packet[1:] | |||||
| n, packet, ok := parseUint32(packet) | |||||
| if !ok || int(n) != len(questions) { | |||||
| return nil, parseError(msgUserAuthInfoResponse) | |||||
| } | |||||
| for i := uint32(0); i < n; i++ { | |||||
| ans, rest, ok := parseString(packet) | |||||
| if !ok { | |||||
| return nil, parseError(msgUserAuthInfoResponse) | |||||
| } | |||||
| answers = append(answers, string(ans)) | |||||
| packet = rest | |||||
| } | |||||
| if len(packet) != 0 { | |||||
| return nil, errors.New("ssh: junk at end of message") | |||||
| } | |||||
| return answers, nil | |||||
| } | |||||
| @@ -1,605 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| // Session implements an interactive session described in | |||||
| // "RFC 4254, section 6". | |||||
| import ( | |||||
| "bytes" | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "io/ioutil" | |||||
| "sync" | |||||
| ) | |||||
| type Signal string | |||||
| // POSIX signals as listed in RFC 4254 Section 6.10. | |||||
| const ( | |||||
| SIGABRT Signal = "ABRT" | |||||
| SIGALRM Signal = "ALRM" | |||||
| SIGFPE Signal = "FPE" | |||||
| SIGHUP Signal = "HUP" | |||||
| SIGILL Signal = "ILL" | |||||
| SIGINT Signal = "INT" | |||||
| SIGKILL Signal = "KILL" | |||||
| SIGPIPE Signal = "PIPE" | |||||
| SIGQUIT Signal = "QUIT" | |||||
| SIGSEGV Signal = "SEGV" | |||||
| SIGTERM Signal = "TERM" | |||||
| SIGUSR1 Signal = "USR1" | |||||
| SIGUSR2 Signal = "USR2" | |||||
| ) | |||||
| var signals = map[Signal]int{ | |||||
| SIGABRT: 6, | |||||
| SIGALRM: 14, | |||||
| SIGFPE: 8, | |||||
| SIGHUP: 1, | |||||
| SIGILL: 4, | |||||
| SIGINT: 2, | |||||
| SIGKILL: 9, | |||||
| SIGPIPE: 13, | |||||
| SIGQUIT: 3, | |||||
| SIGSEGV: 11, | |||||
| SIGTERM: 15, | |||||
| } | |||||
| type TerminalModes map[uint8]uint32 | |||||
| // POSIX terminal mode flags as listed in RFC 4254 Section 8. | |||||
| const ( | |||||
| tty_OP_END = 0 | |||||
| VINTR = 1 | |||||
| VQUIT = 2 | |||||
| VERASE = 3 | |||||
| VKILL = 4 | |||||
| VEOF = 5 | |||||
| VEOL = 6 | |||||
| VEOL2 = 7 | |||||
| VSTART = 8 | |||||
| VSTOP = 9 | |||||
| VSUSP = 10 | |||||
| VDSUSP = 11 | |||||
| VREPRINT = 12 | |||||
| VWERASE = 13 | |||||
| VLNEXT = 14 | |||||
| VFLUSH = 15 | |||||
| VSWTCH = 16 | |||||
| VSTATUS = 17 | |||||
| VDISCARD = 18 | |||||
| IGNPAR = 30 | |||||
| PARMRK = 31 | |||||
| INPCK = 32 | |||||
| ISTRIP = 33 | |||||
| INLCR = 34 | |||||
| IGNCR = 35 | |||||
| ICRNL = 36 | |||||
| IUCLC = 37 | |||||
| IXON = 38 | |||||
| IXANY = 39 | |||||
| IXOFF = 40 | |||||
| IMAXBEL = 41 | |||||
| ISIG = 50 | |||||
| ICANON = 51 | |||||
| XCASE = 52 | |||||
| ECHO = 53 | |||||
| ECHOE = 54 | |||||
| ECHOK = 55 | |||||
| ECHONL = 56 | |||||
| NOFLSH = 57 | |||||
| TOSTOP = 58 | |||||
| IEXTEN = 59 | |||||
| ECHOCTL = 60 | |||||
| ECHOKE = 61 | |||||
| PENDIN = 62 | |||||
| OPOST = 70 | |||||
| OLCUC = 71 | |||||
| ONLCR = 72 | |||||
| OCRNL = 73 | |||||
| ONOCR = 74 | |||||
| ONLRET = 75 | |||||
| CS7 = 90 | |||||
| CS8 = 91 | |||||
| PARENB = 92 | |||||
| PARODD = 93 | |||||
| TTY_OP_ISPEED = 128 | |||||
| TTY_OP_OSPEED = 129 | |||||
| ) | |||||
| // A Session represents a connection to a remote command or shell. | |||||
| type Session struct { | |||||
| // Stdin specifies the remote process's standard input. | |||||
| // If Stdin is nil, the remote process reads from an empty | |||||
| // bytes.Buffer. | |||||
| Stdin io.Reader | |||||
| // Stdout and Stderr specify the remote process's standard | |||||
| // output and error. | |||||
| // | |||||
| // If either is nil, Run connects the corresponding file | |||||
| // descriptor to an instance of ioutil.Discard. There is a | |||||
| // fixed amount of buffering that is shared for the two streams. | |||||
| // If either blocks it may eventually cause the remote | |||||
| // command to block. | |||||
| Stdout io.Writer | |||||
| Stderr io.Writer | |||||
| ch Channel // the channel backing this session | |||||
| started bool // true once Start, Run or Shell is invoked. | |||||
| copyFuncs []func() error | |||||
| errors chan error // one send per copyFunc | |||||
| // true if pipe method is active | |||||
| stdinpipe, stdoutpipe, stderrpipe bool | |||||
| // stdinPipeWriter is non-nil if StdinPipe has not been called | |||||
| // and Stdin was specified by the user; it is the write end of | |||||
| // a pipe connecting Session.Stdin to the stdin channel. | |||||
| stdinPipeWriter io.WriteCloser | |||||
| exitStatus chan error | |||||
| } | |||||
| // SendRequest sends an out-of-band channel request on the SSH channel | |||||
| // underlying the session. | |||||
| func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { | |||||
| return s.ch.SendRequest(name, wantReply, payload) | |||||
| } | |||||
| func (s *Session) Close() error { | |||||
| return s.ch.Close() | |||||
| } | |||||
| // RFC 4254 Section 6.4. | |||||
| type setenvRequest struct { | |||||
| Name string | |||||
| Value string | |||||
| } | |||||
| // Setenv sets an environment variable that will be applied to any | |||||
| // command executed by Shell or Run. | |||||
| func (s *Session) Setenv(name, value string) error { | |||||
| msg := setenvRequest{ | |||||
| Name: name, | |||||
| Value: value, | |||||
| } | |||||
| ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) | |||||
| if err == nil && !ok { | |||||
| err = errors.New("ssh: setenv failed") | |||||
| } | |||||
| return err | |||||
| } | |||||
| // RFC 4254 Section 6.2. | |||||
| type ptyRequestMsg struct { | |||||
| Term string | |||||
| Columns uint32 | |||||
| Rows uint32 | |||||
| Width uint32 | |||||
| Height uint32 | |||||
| Modelist string | |||||
| } | |||||
| // RequestPty requests the association of a pty with the session on the remote host. | |||||
| func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { | |||||
| var tm []byte | |||||
| for k, v := range termmodes { | |||||
| kv := struct { | |||||
| Key byte | |||||
| Val uint32 | |||||
| }{k, v} | |||||
| tm = append(tm, Marshal(&kv)...) | |||||
| } | |||||
| tm = append(tm, tty_OP_END) | |||||
| req := ptyRequestMsg{ | |||||
| Term: term, | |||||
| Columns: uint32(w), | |||||
| Rows: uint32(h), | |||||
| Width: uint32(w * 8), | |||||
| Height: uint32(h * 8), | |||||
| Modelist: string(tm), | |||||
| } | |||||
| ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) | |||||
| if err == nil && !ok { | |||||
| err = errors.New("ssh: pty-req failed") | |||||
| } | |||||
| return err | |||||
| } | |||||
| // RFC 4254 Section 6.5. | |||||
| type subsystemRequestMsg struct { | |||||
| Subsystem string | |||||
| } | |||||
| // RequestSubsystem requests the association of a subsystem with the session on the remote host. | |||||
| // A subsystem is a predefined command that runs in the background when the ssh session is initiated | |||||
| func (s *Session) RequestSubsystem(subsystem string) error { | |||||
| msg := subsystemRequestMsg{ | |||||
| Subsystem: subsystem, | |||||
| } | |||||
| ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) | |||||
| if err == nil && !ok { | |||||
| err = errors.New("ssh: subsystem request failed") | |||||
| } | |||||
| return err | |||||
| } | |||||
| // RFC 4254 Section 6.9. | |||||
| type signalMsg struct { | |||||
| Signal string | |||||
| } | |||||
| // Signal sends the given signal to the remote process. | |||||
| // sig is one of the SIG* constants. | |||||
| func (s *Session) Signal(sig Signal) error { | |||||
| msg := signalMsg{ | |||||
| Signal: string(sig), | |||||
| } | |||||
| _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) | |||||
| return err | |||||
| } | |||||
| // RFC 4254 Section 6.5. | |||||
| type execMsg struct { | |||||
| Command string | |||||
| } | |||||
| // Start runs cmd on the remote host. Typically, the remote | |||||
| // server passes cmd to the shell for interpretation. | |||||
| // A Session only accepts one call to Run, Start or Shell. | |||||
| func (s *Session) Start(cmd string) error { | |||||
| if s.started { | |||||
| return errors.New("ssh: session already started") | |||||
| } | |||||
| req := execMsg{ | |||||
| Command: cmd, | |||||
| } | |||||
| ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) | |||||
| if err == nil && !ok { | |||||
| err = fmt.Errorf("ssh: command %v failed", cmd) | |||||
| } | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return s.start() | |||||
| } | |||||
| // Run runs cmd on the remote host. Typically, the remote | |||||
| // server passes cmd to the shell for interpretation. | |||||
| // A Session only accepts one call to Run, Start, Shell, Output, | |||||
| // or CombinedOutput. | |||||
| // | |||||
| // The returned error is nil if the command runs, has no problems | |||||
| // copying stdin, stdout, and stderr, and exits with a zero exit | |||||
| // status. | |||||
| // | |||||
| // If the command fails to run or doesn't complete successfully, the | |||||
| // error is of type *ExitError. Other error types may be | |||||
| // returned for I/O problems. | |||||
| func (s *Session) Run(cmd string) error { | |||||
| err := s.Start(cmd) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return s.Wait() | |||||
| } | |||||
| // Output runs cmd on the remote host and returns its standard output. | |||||
| func (s *Session) Output(cmd string) ([]byte, error) { | |||||
| if s.Stdout != nil { | |||||
| return nil, errors.New("ssh: Stdout already set") | |||||
| } | |||||
| var b bytes.Buffer | |||||
| s.Stdout = &b | |||||
| err := s.Run(cmd) | |||||
| return b.Bytes(), err | |||||
| } | |||||
| type singleWriter struct { | |||||
| b bytes.Buffer | |||||
| mu sync.Mutex | |||||
| } | |||||
| func (w *singleWriter) Write(p []byte) (int, error) { | |||||
| w.mu.Lock() | |||||
| defer w.mu.Unlock() | |||||
| return w.b.Write(p) | |||||
| } | |||||
| // CombinedOutput runs cmd on the remote host and returns its combined | |||||
| // standard output and standard error. | |||||
| func (s *Session) CombinedOutput(cmd string) ([]byte, error) { | |||||
| if s.Stdout != nil { | |||||
| return nil, errors.New("ssh: Stdout already set") | |||||
| } | |||||
| if s.Stderr != nil { | |||||
| return nil, errors.New("ssh: Stderr already set") | |||||
| } | |||||
| var b singleWriter | |||||
| s.Stdout = &b | |||||
| s.Stderr = &b | |||||
| err := s.Run(cmd) | |||||
| return b.b.Bytes(), err | |||||
| } | |||||
| // Shell starts a login shell on the remote host. A Session only | |||||
| // accepts one call to Run, Start, Shell, Output, or CombinedOutput. | |||||
| func (s *Session) Shell() error { | |||||
| if s.started { | |||||
| return errors.New("ssh: session already started") | |||||
| } | |||||
| ok, err := s.ch.SendRequest("shell", true, nil) | |||||
| if err == nil && !ok { | |||||
| return fmt.Errorf("ssh: cound not start shell") | |||||
| } | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| return s.start() | |||||
| } | |||||
| func (s *Session) start() error { | |||||
| s.started = true | |||||
| type F func(*Session) | |||||
| for _, setupFd := range []F{(*Session).stdin, (*Session).stdout, (*Session).stderr} { | |||||
| setupFd(s) | |||||
| } | |||||
| s.errors = make(chan error, len(s.copyFuncs)) | |||||
| for _, fn := range s.copyFuncs { | |||||
| go func(fn func() error) { | |||||
| s.errors <- fn() | |||||
| }(fn) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // Wait waits for the remote command to exit. | |||||
| // | |||||
| // The returned error is nil if the command runs, has no problems | |||||
| // copying stdin, stdout, and stderr, and exits with a zero exit | |||||
| // status. | |||||
| // | |||||
| // If the command fails to run or doesn't complete successfully, the | |||||
| // error is of type *ExitError. Other error types may be | |||||
| // returned for I/O problems. | |||||
| func (s *Session) Wait() error { | |||||
| if !s.started { | |||||
| return errors.New("ssh: session not started") | |||||
| } | |||||
| waitErr := <-s.exitStatus | |||||
| if s.stdinPipeWriter != nil { | |||||
| s.stdinPipeWriter.Close() | |||||
| } | |||||
| var copyError error | |||||
| for _ = range s.copyFuncs { | |||||
| if err := <-s.errors; err != nil && copyError == nil { | |||||
| copyError = err | |||||
| } | |||||
| } | |||||
| if waitErr != nil { | |||||
| return waitErr | |||||
| } | |||||
| return copyError | |||||
| } | |||||
| func (s *Session) wait(reqs <-chan *Request) error { | |||||
| wm := Waitmsg{status: -1} | |||||
| // Wait for msg channel to be closed before returning. | |||||
| for msg := range reqs { | |||||
| switch msg.Type { | |||||
| case "exit-status": | |||||
| d := msg.Payload | |||||
| wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) | |||||
| case "exit-signal": | |||||
| var sigval struct { | |||||
| Signal string | |||||
| CoreDumped bool | |||||
| Error string | |||||
| Lang string | |||||
| } | |||||
| if err := Unmarshal(msg.Payload, &sigval); err != nil { | |||||
| return err | |||||
| } | |||||
| // Must sanitize strings? | |||||
| wm.signal = sigval.Signal | |||||
| wm.msg = sigval.Error | |||||
| wm.lang = sigval.Lang | |||||
| default: | |||||
| // This handles keepalives and matches | |||||
| // OpenSSH's behaviour. | |||||
| if msg.WantReply { | |||||
| msg.Reply(false, nil) | |||||
| } | |||||
| } | |||||
| } | |||||
| if wm.status == 0 { | |||||
| return nil | |||||
| } | |||||
| if wm.status == -1 { | |||||
| // exit-status was never sent from server | |||||
| if wm.signal == "" { | |||||
| return errors.New("wait: remote command exited without exit status or exit signal") | |||||
| } | |||||
| wm.status = 128 | |||||
| if _, ok := signals[Signal(wm.signal)]; ok { | |||||
| wm.status += signals[Signal(wm.signal)] | |||||
| } | |||||
| } | |||||
| return &ExitError{wm} | |||||
| } | |||||
| func (s *Session) stdin() { | |||||
| if s.stdinpipe { | |||||
| return | |||||
| } | |||||
| var stdin io.Reader | |||||
| if s.Stdin == nil { | |||||
| stdin = new(bytes.Buffer) | |||||
| } else { | |||||
| r, w := io.Pipe() | |||||
| go func() { | |||||
| _, err := io.Copy(w, s.Stdin) | |||||
| w.CloseWithError(err) | |||||
| }() | |||||
| stdin, s.stdinPipeWriter = r, w | |||||
| } | |||||
| s.copyFuncs = append(s.copyFuncs, func() error { | |||||
| _, err := io.Copy(s.ch, stdin) | |||||
| if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { | |||||
| err = err1 | |||||
| } | |||||
| return err | |||||
| }) | |||||
| } | |||||
| func (s *Session) stdout() { | |||||
| if s.stdoutpipe { | |||||
| return | |||||
| } | |||||
| if s.Stdout == nil { | |||||
| s.Stdout = ioutil.Discard | |||||
| } | |||||
| s.copyFuncs = append(s.copyFuncs, func() error { | |||||
| _, err := io.Copy(s.Stdout, s.ch) | |||||
| return err | |||||
| }) | |||||
| } | |||||
| func (s *Session) stderr() { | |||||
| if s.stderrpipe { | |||||
| return | |||||
| } | |||||
| if s.Stderr == nil { | |||||
| s.Stderr = ioutil.Discard | |||||
| } | |||||
| s.copyFuncs = append(s.copyFuncs, func() error { | |||||
| _, err := io.Copy(s.Stderr, s.ch.Stderr()) | |||||
| return err | |||||
| }) | |||||
| } | |||||
| // sessionStdin reroutes Close to CloseWrite. | |||||
| type sessionStdin struct { | |||||
| io.Writer | |||||
| ch Channel | |||||
| } | |||||
| func (s *sessionStdin) Close() error { | |||||
| return s.ch.CloseWrite() | |||||
| } | |||||
| // StdinPipe returns a pipe that will be connected to the | |||||
| // remote command's standard input when the command starts. | |||||
| func (s *Session) StdinPipe() (io.WriteCloser, error) { | |||||
| if s.Stdin != nil { | |||||
| return nil, errors.New("ssh: Stdin already set") | |||||
| } | |||||
| if s.started { | |||||
| return nil, errors.New("ssh: StdinPipe after process started") | |||||
| } | |||||
| s.stdinpipe = true | |||||
| return &sessionStdin{s.ch, s.ch}, nil | |||||
| } | |||||
| // StdoutPipe returns a pipe that will be connected to the | |||||
| // remote command's standard output when the command starts. | |||||
| // There is a fixed amount of buffering that is shared between | |||||
| // stdout and stderr streams. If the StdoutPipe reader is | |||||
| // not serviced fast enough it may eventually cause the | |||||
| // remote command to block. | |||||
| func (s *Session) StdoutPipe() (io.Reader, error) { | |||||
| if s.Stdout != nil { | |||||
| return nil, errors.New("ssh: Stdout already set") | |||||
| } | |||||
| if s.started { | |||||
| return nil, errors.New("ssh: StdoutPipe after process started") | |||||
| } | |||||
| s.stdoutpipe = true | |||||
| return s.ch, nil | |||||
| } | |||||
| // StderrPipe returns a pipe that will be connected to the | |||||
| // remote command's standard error when the command starts. | |||||
| // There is a fixed amount of buffering that is shared between | |||||
| // stdout and stderr streams. If the StderrPipe reader is | |||||
| // not serviced fast enough it may eventually cause the | |||||
| // remote command to block. | |||||
| func (s *Session) StderrPipe() (io.Reader, error) { | |||||
| if s.Stderr != nil { | |||||
| return nil, errors.New("ssh: Stderr already set") | |||||
| } | |||||
| if s.started { | |||||
| return nil, errors.New("ssh: StderrPipe after process started") | |||||
| } | |||||
| s.stderrpipe = true | |||||
| return s.ch.Stderr(), nil | |||||
| } | |||||
| // newSession returns a new interactive session on the remote host. | |||||
| func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { | |||||
| s := &Session{ | |||||
| ch: ch, | |||||
| } | |||||
| s.exitStatus = make(chan error, 1) | |||||
| go func() { | |||||
| s.exitStatus <- s.wait(reqs) | |||||
| }() | |||||
| return s, nil | |||||
| } | |||||
| // An ExitError reports unsuccessful completion of a remote command. | |||||
| type ExitError struct { | |||||
| Waitmsg | |||||
| } | |||||
| func (e *ExitError) Error() string { | |||||
| return e.Waitmsg.String() | |||||
| } | |||||
| // Waitmsg stores the information about an exited remote command | |||||
| // as reported by Wait. | |||||
| type Waitmsg struct { | |||||
| status int | |||||
| signal string | |||||
| msg string | |||||
| lang string | |||||
| } | |||||
| // ExitStatus returns the exit status of the remote command. | |||||
| func (w Waitmsg) ExitStatus() int { | |||||
| return w.status | |||||
| } | |||||
| // Signal returns the exit signal of the remote command if | |||||
| // it was terminated violently. | |||||
| func (w Waitmsg) Signal() string { | |||||
| return w.signal | |||||
| } | |||||
| // Msg returns the exit message given by the remote command | |||||
| func (w Waitmsg) Msg() string { | |||||
| return w.msg | |||||
| } | |||||
| // Lang returns the language tag. See RFC 3066 | |||||
| func (w Waitmsg) Lang() string { | |||||
| return w.lang | |||||
| } | |||||
| func (w Waitmsg) String() string { | |||||
| return fmt.Sprintf("Process exited with: %v. Reason was: %v (%v)", w.status, w.msg, w.signal) | |||||
| } | |||||
| @@ -1,774 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| // Session tests. | |||||
| import ( | |||||
| "bytes" | |||||
| crypto_rand "crypto/rand" | |||||
| "errors" | |||||
| "io" | |||||
| "io/ioutil" | |||||
| "math/rand" | |||||
| "net" | |||||
| "testing" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh/terminal" | |||||
| ) | |||||
| type serverType func(Channel, <-chan *Request, *testing.T) | |||||
| // dial constructs a new test server and returns a *ClientConn. | |||||
| func dial(handler serverType, t *testing.T) *Client { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| go func() { | |||||
| defer c1.Close() | |||||
| conf := ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| conf.AddHostKey(testSigners["rsa"]) | |||||
| _, chans, reqs, err := NewServerConn(c1, &conf) | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to handshake: %v", err) | |||||
| } | |||||
| go DiscardRequests(reqs) | |||||
| for newCh := range chans { | |||||
| if newCh.ChannelType() != "session" { | |||||
| newCh.Reject(UnknownChannelType, "unknown channel type") | |||||
| continue | |||||
| } | |||||
| ch, inReqs, err := newCh.Accept() | |||||
| if err != nil { | |||||
| t.Errorf("Accept: %v", err) | |||||
| continue | |||||
| } | |||||
| go func() { | |||||
| handler(ch, inReqs, t) | |||||
| }() | |||||
| } | |||||
| }() | |||||
| config := &ClientConfig{ | |||||
| User: "testuser", | |||||
| } | |||||
| conn, chans, reqs, err := NewClientConn(c2, "", config) | |||||
| if err != nil { | |||||
| t.Fatalf("unable to dial remote side: %v", err) | |||||
| } | |||||
| return NewClient(conn, chans, reqs) | |||||
| } | |||||
| // Test a simple string is returned to session.Stdout. | |||||
| func TestSessionShell(t *testing.T) { | |||||
| conn := dial(shellHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| stdout := new(bytes.Buffer) | |||||
| session.Stdout = stdout | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %s", err) | |||||
| } | |||||
| if err := session.Wait(); err != nil { | |||||
| t.Fatalf("Remote command did not exit cleanly: %v", err) | |||||
| } | |||||
| actual := stdout.String() | |||||
| if actual != "golang" { | |||||
| t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) | |||||
| } | |||||
| } | |||||
| // TODO(dfc) add support for Std{in,err}Pipe when the Server supports it. | |||||
| // Test a simple string is returned via StdoutPipe. | |||||
| func TestSessionStdoutPipe(t *testing.T) { | |||||
| conn := dial(shellHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| stdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request StdoutPipe(): %v", err) | |||||
| } | |||||
| var buf bytes.Buffer | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| done := make(chan bool, 1) | |||||
| go func() { | |||||
| if _, err := io.Copy(&buf, stdout); err != nil { | |||||
| t.Errorf("Copy of stdout failed: %v", err) | |||||
| } | |||||
| done <- true | |||||
| }() | |||||
| if err := session.Wait(); err != nil { | |||||
| t.Fatalf("Remote command did not exit cleanly: %v", err) | |||||
| } | |||||
| <-done | |||||
| actual := buf.String() | |||||
| if actual != "golang" { | |||||
| t.Fatalf("Remote shell did not return expected string: expected=golang, actual=%s", actual) | |||||
| } | |||||
| } | |||||
| // Test that a simple string is returned via the Output helper, | |||||
| // and that stderr is discarded. | |||||
| func TestSessionOutput(t *testing.T) { | |||||
| conn := dial(fixedOutputHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| buf, err := session.Output("") // cmd is ignored by fixedOutputHandler | |||||
| if err != nil { | |||||
| t.Error("Remote command did not exit cleanly:", err) | |||||
| } | |||||
| w := "this-is-stdout." | |||||
| g := string(buf) | |||||
| if g != w { | |||||
| t.Error("Remote command did not return expected string:") | |||||
| t.Logf("want %q", w) | |||||
| t.Logf("got %q", g) | |||||
| } | |||||
| } | |||||
| // Test that both stdout and stderr are returned | |||||
| // via the CombinedOutput helper. | |||||
| func TestSessionCombinedOutput(t *testing.T) { | |||||
| conn := dial(fixedOutputHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| buf, err := session.CombinedOutput("") // cmd is ignored by fixedOutputHandler | |||||
| if err != nil { | |||||
| t.Error("Remote command did not exit cleanly:", err) | |||||
| } | |||||
| const stdout = "this-is-stdout." | |||||
| const stderr = "this-is-stderr." | |||||
| g := string(buf) | |||||
| if g != stdout+stderr && g != stderr+stdout { | |||||
| t.Error("Remote command did not return expected string:") | |||||
| t.Logf("want %q, or %q", stdout+stderr, stderr+stdout) | |||||
| t.Logf("got %q", g) | |||||
| } | |||||
| } | |||||
| // Test non-0 exit status is returned correctly. | |||||
| func TestExitStatusNonZero(t *testing.T) { | |||||
| conn := dial(exitStatusNonZeroHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err == nil { | |||||
| t.Fatalf("expected command to fail but it didn't") | |||||
| } | |||||
| e, ok := err.(*ExitError) | |||||
| if !ok { | |||||
| t.Fatalf("expected *ExitError but got %T", err) | |||||
| } | |||||
| if e.ExitStatus() != 15 { | |||||
| t.Fatalf("expected command to exit with 15 but got %v", e.ExitStatus()) | |||||
| } | |||||
| } | |||||
| // Test 0 exit status is returned correctly. | |||||
| func TestExitStatusZero(t *testing.T) { | |||||
| conn := dial(exitStatusZeroHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err != nil { | |||||
| t.Fatalf("expected nil but got %v", err) | |||||
| } | |||||
| } | |||||
| // Test exit signal and status are both returned correctly. | |||||
| func TestExitSignalAndStatus(t *testing.T) { | |||||
| conn := dial(exitSignalAndStatusHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err == nil { | |||||
| t.Fatalf("expected command to fail but it didn't") | |||||
| } | |||||
| e, ok := err.(*ExitError) | |||||
| if !ok { | |||||
| t.Fatalf("expected *ExitError but got %T", err) | |||||
| } | |||||
| if e.Signal() != "TERM" || e.ExitStatus() != 15 { | |||||
| t.Fatalf("expected command to exit with signal TERM and status 15 but got signal %s and status %v", e.Signal(), e.ExitStatus()) | |||||
| } | |||||
| } | |||||
| // Test exit signal and status are both returned correctly. | |||||
| func TestKnownExitSignalOnly(t *testing.T) { | |||||
| conn := dial(exitSignalHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err == nil { | |||||
| t.Fatalf("expected command to fail but it didn't") | |||||
| } | |||||
| e, ok := err.(*ExitError) | |||||
| if !ok { | |||||
| t.Fatalf("expected *ExitError but got %T", err) | |||||
| } | |||||
| if e.Signal() != "TERM" || e.ExitStatus() != 143 { | |||||
| t.Fatalf("expected command to exit with signal TERM and status 143 but got signal %s and status %v", e.Signal(), e.ExitStatus()) | |||||
| } | |||||
| } | |||||
| // Test exit signal and status are both returned correctly. | |||||
| func TestUnknownExitSignal(t *testing.T) { | |||||
| conn := dial(exitSignalUnknownHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err == nil { | |||||
| t.Fatalf("expected command to fail but it didn't") | |||||
| } | |||||
| e, ok := err.(*ExitError) | |||||
| if !ok { | |||||
| t.Fatalf("expected *ExitError but got %T", err) | |||||
| } | |||||
| if e.Signal() != "SYS" || e.ExitStatus() != 128 { | |||||
| t.Fatalf("expected command to exit with signal SYS and status 128 but got signal %s and status %v", e.Signal(), e.ExitStatus()) | |||||
| } | |||||
| } | |||||
| // Test WaitMsg is not returned if the channel closes abruptly. | |||||
| func TestExitWithoutStatusOrSignal(t *testing.T) { | |||||
| conn := dial(exitWithoutSignalOrStatus, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("Unable to request new session: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err == nil { | |||||
| t.Fatalf("expected command to fail but it didn't") | |||||
| } | |||||
| _, ok := err.(*ExitError) | |||||
| if ok { | |||||
| // you can't actually test for errors.errorString | |||||
| // because it's not exported. | |||||
| t.Fatalf("expected *errorString but got %T", err) | |||||
| } | |||||
| } | |||||
| // windowTestBytes is the number of bytes that we'll send to the SSH server. | |||||
| const windowTestBytes = 16000 * 200 | |||||
| // TestServerWindow writes random data to the server. The server is expected to echo | |||||
| // the same data back, which is compared against the original. | |||||
| func TestServerWindow(t *testing.T) { | |||||
| origBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) | |||||
| io.CopyN(origBuf, crypto_rand.Reader, windowTestBytes) | |||||
| origBytes := origBuf.Bytes() | |||||
| conn := dial(echoHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| defer session.Close() | |||||
| result := make(chan []byte) | |||||
| go func() { | |||||
| defer close(result) | |||||
| echoedBuf := bytes.NewBuffer(make([]byte, 0, windowTestBytes)) | |||||
| serverStdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Errorf("StdoutPipe failed: %v", err) | |||||
| return | |||||
| } | |||||
| n, err := copyNRandomly("stdout", echoedBuf, serverStdout, windowTestBytes) | |||||
| if err != nil && err != io.EOF { | |||||
| t.Errorf("Read only %d bytes from server, expected %d: %v", n, windowTestBytes, err) | |||||
| } | |||||
| result <- echoedBuf.Bytes() | |||||
| }() | |||||
| serverStdin, err := session.StdinPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("StdinPipe failed: %v", err) | |||||
| } | |||||
| written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestBytes) | |||||
| if err != nil { | |||||
| t.Fatalf("failed to copy origBuf to serverStdin: %v", err) | |||||
| } | |||||
| if written != windowTestBytes { | |||||
| t.Fatalf("Wrote only %d of %d bytes to server", written, windowTestBytes) | |||||
| } | |||||
| echoedBytes := <-result | |||||
| if !bytes.Equal(origBytes, echoedBytes) { | |||||
| t.Fatalf("Echoed buffer differed from original, orig %d, echoed %d", len(origBytes), len(echoedBytes)) | |||||
| } | |||||
| } | |||||
| // Verify the client can handle a keepalive packet from the server. | |||||
| func TestClientHandlesKeepalives(t *testing.T) { | |||||
| conn := dial(channelKeepaliveSender, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err := session.Shell(); err != nil { | |||||
| t.Fatalf("Unable to execute command: %v", err) | |||||
| } | |||||
| err = session.Wait() | |||||
| if err != nil { | |||||
| t.Fatalf("expected nil but got: %v", err) | |||||
| } | |||||
| } | |||||
| type exitStatusMsg struct { | |||||
| Status uint32 | |||||
| } | |||||
| type exitSignalMsg struct { | |||||
| Signal string | |||||
| CoreDumped bool | |||||
| Errmsg string | |||||
| Lang string | |||||
| } | |||||
| func handleTerminalRequests(in <-chan *Request) { | |||||
| for req := range in { | |||||
| ok := false | |||||
| switch req.Type { | |||||
| case "shell": | |||||
| ok = true | |||||
| if len(req.Payload) > 0 { | |||||
| // We don't accept any commands, only the default shell. | |||||
| ok = false | |||||
| } | |||||
| case "env": | |||||
| ok = true | |||||
| } | |||||
| req.Reply(ok, nil) | |||||
| } | |||||
| } | |||||
| func newServerShell(ch Channel, in <-chan *Request, prompt string) *terminal.Terminal { | |||||
| term := terminal.NewTerminal(ch, prompt) | |||||
| go handleTerminalRequests(in) | |||||
| return term | |||||
| } | |||||
| func exitStatusZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| // this string is returned to stdout | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| sendStatus(0, ch, t) | |||||
| } | |||||
| func exitStatusNonZeroHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| sendStatus(15, ch, t) | |||||
| } | |||||
| func exitSignalAndStatusHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| sendStatus(15, ch, t) | |||||
| sendSignal("TERM", ch, t) | |||||
| } | |||||
| func exitSignalHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| sendSignal("TERM", ch, t) | |||||
| } | |||||
| func exitSignalUnknownHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| sendSignal("SYS", ch, t) | |||||
| } | |||||
| func exitWithoutSignalOrStatus(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| } | |||||
| func shellHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| // this string is returned to stdout | |||||
| shell := newServerShell(ch, in, "golang") | |||||
| readLine(shell, t) | |||||
| sendStatus(0, ch, t) | |||||
| } | |||||
| // Ignores the command, writes fixed strings to stderr and stdout. | |||||
| // Strings are "this-is-stdout." and "this-is-stderr.". | |||||
| func fixedOutputHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| _, err := ch.Read(nil) | |||||
| req, ok := <-in | |||||
| if !ok { | |||||
| t.Fatalf("error: expected channel request, got: %#v", err) | |||||
| return | |||||
| } | |||||
| // ignore request, always send some text | |||||
| req.Reply(true, nil) | |||||
| _, err = io.WriteString(ch, "this-is-stdout.") | |||||
| if err != nil { | |||||
| t.Fatalf("error writing on server: %v", err) | |||||
| } | |||||
| _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") | |||||
| if err != nil { | |||||
| t.Fatalf("error writing on server: %v", err) | |||||
| } | |||||
| sendStatus(0, ch, t) | |||||
| } | |||||
| func readLine(shell *terminal.Terminal, t *testing.T) { | |||||
| if _, err := shell.ReadLine(); err != nil && err != io.EOF { | |||||
| t.Errorf("unable to read line: %v", err) | |||||
| } | |||||
| } | |||||
| func sendStatus(status uint32, ch Channel, t *testing.T) { | |||||
| msg := exitStatusMsg{ | |||||
| Status: status, | |||||
| } | |||||
| if _, err := ch.SendRequest("exit-status", false, Marshal(&msg)); err != nil { | |||||
| t.Errorf("unable to send status: %v", err) | |||||
| } | |||||
| } | |||||
| func sendSignal(signal string, ch Channel, t *testing.T) { | |||||
| sig := exitSignalMsg{ | |||||
| Signal: signal, | |||||
| CoreDumped: false, | |||||
| Errmsg: "Process terminated", | |||||
| Lang: "en-GB-oed", | |||||
| } | |||||
| if _, err := ch.SendRequest("exit-signal", false, Marshal(&sig)); err != nil { | |||||
| t.Errorf("unable to send signal: %v", err) | |||||
| } | |||||
| } | |||||
| func discardHandler(ch Channel, t *testing.T) { | |||||
| defer ch.Close() | |||||
| io.Copy(ioutil.Discard, ch) | |||||
| } | |||||
| func echoHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err != nil { | |||||
| t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTestBytes, err) | |||||
| } | |||||
| } | |||||
| // copyNRandomly copies n bytes from src to dst. It uses a variable, and random, | |||||
| // buffer size to exercise more code paths. | |||||
| func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, error) { | |||||
| var ( | |||||
| buf = make([]byte, 32*1024) | |||||
| written int | |||||
| remaining = n | |||||
| ) | |||||
| for remaining > 0 { | |||||
| l := rand.Intn(1 << 15) | |||||
| if remaining < l { | |||||
| l = remaining | |||||
| } | |||||
| nr, er := src.Read(buf[:l]) | |||||
| nw, ew := dst.Write(buf[:nr]) | |||||
| remaining -= nw | |||||
| written += nw | |||||
| if ew != nil { | |||||
| return written, ew | |||||
| } | |||||
| if nr != nw { | |||||
| return written, io.ErrShortWrite | |||||
| } | |||||
| if er != nil && er != io.EOF { | |||||
| return written, er | |||||
| } | |||||
| } | |||||
| return written, nil | |||||
| } | |||||
| func channelKeepaliveSender(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| shell := newServerShell(ch, in, "> ") | |||||
| readLine(shell, t) | |||||
| if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err != nil { | |||||
| t.Errorf("unable to send channel keepalive request: %v", err) | |||||
| } | |||||
| sendStatus(0, ch, t) | |||||
| } | |||||
| func TestClientWriteEOF(t *testing.T) { | |||||
| conn := dial(simpleEchoHandler, t) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| defer session.Close() | |||||
| stdin, err := session.StdinPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("StdinPipe failed: %v", err) | |||||
| } | |||||
| stdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("StdoutPipe failed: %v", err) | |||||
| } | |||||
| data := []byte(`0000`) | |||||
| _, err = stdin.Write(data) | |||||
| if err != nil { | |||||
| t.Fatalf("Write failed: %v", err) | |||||
| } | |||||
| stdin.Close() | |||||
| res, err := ioutil.ReadAll(stdout) | |||||
| if err != nil { | |||||
| t.Fatalf("Read failed: %v", err) | |||||
| } | |||||
| if !bytes.Equal(data, res) { | |||||
| t.Fatalf("Read differed from write, wrote: %v, read: %v", data, res) | |||||
| } | |||||
| } | |||||
| func simpleEchoHandler(ch Channel, in <-chan *Request, t *testing.T) { | |||||
| defer ch.Close() | |||||
| data, err := ioutil.ReadAll(ch) | |||||
| if err != nil { | |||||
| t.Errorf("handler read error: %v", err) | |||||
| } | |||||
| _, err = ch.Write(data) | |||||
| if err != nil { | |||||
| t.Errorf("handler write error: %v", err) | |||||
| } | |||||
| } | |||||
| func TestSessionID(t *testing.T) { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| serverID := make(chan []byte, 1) | |||||
| clientID := make(chan []byte, 1) | |||||
| serverConf := &ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| serverConf.AddHostKey(testSigners["ecdsa"]) | |||||
| clientConf := &ClientConfig{ | |||||
| User: "user", | |||||
| } | |||||
| go func() { | |||||
| conn, chans, reqs, err := NewServerConn(c1, serverConf) | |||||
| if err != nil { | |||||
| t.Fatalf("server handshake: %v", err) | |||||
| } | |||||
| serverID <- conn.SessionID() | |||||
| go DiscardRequests(reqs) | |||||
| for ch := range chans { | |||||
| ch.Reject(Prohibited, "") | |||||
| } | |||||
| }() | |||||
| go func() { | |||||
| conn, chans, reqs, err := NewClientConn(c2, "", clientConf) | |||||
| if err != nil { | |||||
| t.Fatalf("client handshake: %v", err) | |||||
| } | |||||
| clientID <- conn.SessionID() | |||||
| go DiscardRequests(reqs) | |||||
| for ch := range chans { | |||||
| ch.Reject(Prohibited, "") | |||||
| } | |||||
| }() | |||||
| s := <-serverID | |||||
| c := <-clientID | |||||
| if bytes.Compare(s, c) != 0 { | |||||
| t.Errorf("server session ID (%x) != client session ID (%x)", s, c) | |||||
| } else if len(s) == 0 { | |||||
| t.Errorf("client and server SessionID were empty.") | |||||
| } | |||||
| } | |||||
| type noReadConn struct { | |||||
| readSeen bool | |||||
| net.Conn | |||||
| } | |||||
| func (c *noReadConn) Close() error { | |||||
| return nil | |||||
| } | |||||
| func (c *noReadConn) Read(b []byte) (int, error) { | |||||
| c.readSeen = true | |||||
| return 0, errors.New("noReadConn error") | |||||
| } | |||||
| func TestInvalidServerConfiguration(t *testing.T) { | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| serveConn := noReadConn{Conn: c1} | |||||
| serverConf := &ServerConfig{} | |||||
| NewServerConn(&serveConn, serverConf) | |||||
| if serveConn.readSeen { | |||||
| t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing host key") | |||||
| } | |||||
| serverConf.AddHostKey(testSigners["ecdsa"]) | |||||
| NewServerConn(&serveConn, serverConf) | |||||
| if serveConn.readSeen { | |||||
| t.Fatalf("NewServerConn attempted to Read() from Conn while configuration is missing authentication method") | |||||
| } | |||||
| } | |||||
| func TestHostKeyAlgorithms(t *testing.T) { | |||||
| serverConf := &ServerConfig{ | |||||
| NoClientAuth: true, | |||||
| } | |||||
| serverConf.AddHostKey(testSigners["rsa"]) | |||||
| serverConf.AddHostKey(testSigners["ecdsa"]) | |||||
| connect := func(clientConf *ClientConfig, want string) { | |||||
| var alg string | |||||
| clientConf.HostKeyCallback = func(h string, a net.Addr, key PublicKey) error { | |||||
| alg = key.Type() | |||||
| return nil | |||||
| } | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| go NewServerConn(c1, serverConf) | |||||
| _, _, _, err = NewClientConn(c2, "", clientConf) | |||||
| if err != nil { | |||||
| t.Fatalf("NewClientConn: %v", err) | |||||
| } | |||||
| if alg != want { | |||||
| t.Errorf("selected key algorithm %s, want %s", alg, want) | |||||
| } | |||||
| } | |||||
| // By default, we get the preferred algorithm, which is ECDSA 256. | |||||
| clientConf := &ClientConfig{} | |||||
| connect(clientConf, KeyAlgoECDSA256) | |||||
| // Client asks for RSA explicitly. | |||||
| clientConf.HostKeyAlgorithms = []string{KeyAlgoRSA} | |||||
| connect(clientConf, KeyAlgoRSA) | |||||
| c1, c2, err := netPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("netPipe: %v", err) | |||||
| } | |||||
| defer c1.Close() | |||||
| defer c2.Close() | |||||
| go NewServerConn(c1, serverConf) | |||||
| clientConf.HostKeyAlgorithms = []string{"nonexistent-hostkey-algo"} | |||||
| _, _, _, err = NewClientConn(c2, "", clientConf) | |||||
| if err == nil { | |||||
| t.Fatal("succeeded connecting with unknown hostkey algorithm") | |||||
| } | |||||
| } | |||||
| @@ -1,407 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "errors" | |||||
| "fmt" | |||||
| "io" | |||||
| "math/rand" | |||||
| "net" | |||||
| "strconv" | |||||
| "strings" | |||||
| "sync" | |||||
| "time" | |||||
| ) | |||||
| // Listen requests the remote peer open a listening socket on | |||||
| // addr. Incoming connections will be available by calling Accept on | |||||
| // the returned net.Listener. The listener must be serviced, or the | |||||
| // SSH connection may hang. | |||||
| func (c *Client) Listen(n, addr string) (net.Listener, error) { | |||||
| laddr, err := net.ResolveTCPAddr(n, addr) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return c.ListenTCP(laddr) | |||||
| } | |||||
| // Automatic port allocation is broken with OpenSSH before 6.0. See | |||||
| // also https://bugzilla.mindrot.org/show_bug.cgi?id=2017. In | |||||
| // particular, OpenSSH 5.9 sends a channelOpenMsg with port number 0, | |||||
| // rather than the actual port number. This means you can never open | |||||
| // two different listeners with auto allocated ports. We work around | |||||
| // this by trying explicit ports until we succeed. | |||||
| const openSSHPrefix = "OpenSSH_" | |||||
| var portRandomizer = rand.New(rand.NewSource(time.Now().UnixNano())) | |||||
| // isBrokenOpenSSHVersion returns true if the given version string | |||||
| // specifies a version of OpenSSH that is known to have a bug in port | |||||
| // forwarding. | |||||
| func isBrokenOpenSSHVersion(versionStr string) bool { | |||||
| i := strings.Index(versionStr, openSSHPrefix) | |||||
| if i < 0 { | |||||
| return false | |||||
| } | |||||
| i += len(openSSHPrefix) | |||||
| j := i | |||||
| for ; j < len(versionStr); j++ { | |||||
| if versionStr[j] < '0' || versionStr[j] > '9' { | |||||
| break | |||||
| } | |||||
| } | |||||
| version, _ := strconv.Atoi(versionStr[i:j]) | |||||
| return version < 6 | |||||
| } | |||||
| // autoPortListenWorkaround simulates automatic port allocation by | |||||
| // trying random ports repeatedly. | |||||
| func (c *Client) autoPortListenWorkaround(laddr *net.TCPAddr) (net.Listener, error) { | |||||
| var sshListener net.Listener | |||||
| var err error | |||||
| const tries = 10 | |||||
| for i := 0; i < tries; i++ { | |||||
| addr := *laddr | |||||
| addr.Port = 1024 + portRandomizer.Intn(60000) | |||||
| sshListener, err = c.ListenTCP(&addr) | |||||
| if err == nil { | |||||
| laddr.Port = addr.Port | |||||
| return sshListener, err | |||||
| } | |||||
| } | |||||
| return nil, fmt.Errorf("ssh: listen on random port failed after %d tries: %v", tries, err) | |||||
| } | |||||
| // RFC 4254 7.1 | |||||
| type channelForwardMsg struct { | |||||
| addr string | |||||
| rport uint32 | |||||
| } | |||||
| // ListenTCP requests the remote peer open a listening socket | |||||
| // on laddr. Incoming connections will be available by calling | |||||
| // Accept on the returned net.Listener. | |||||
| func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { | |||||
| if laddr.Port == 0 && isBrokenOpenSSHVersion(string(c.ServerVersion())) { | |||||
| return c.autoPortListenWorkaround(laddr) | |||||
| } | |||||
| m := channelForwardMsg{ | |||||
| laddr.IP.String(), | |||||
| uint32(laddr.Port), | |||||
| } | |||||
| // send message | |||||
| ok, resp, err := c.SendRequest("tcpip-forward", true, Marshal(&m)) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if !ok { | |||||
| return nil, errors.New("ssh: tcpip-forward request denied by peer") | |||||
| } | |||||
| // If the original port was 0, then the remote side will | |||||
| // supply a real port number in the response. | |||||
| if laddr.Port == 0 { | |||||
| var p struct { | |||||
| Port uint32 | |||||
| } | |||||
| if err := Unmarshal(resp, &p); err != nil { | |||||
| return nil, err | |||||
| } | |||||
| laddr.Port = int(p.Port) | |||||
| } | |||||
| // Register this forward, using the port number we obtained. | |||||
| ch := c.forwards.add(*laddr) | |||||
| return &tcpListener{laddr, c, ch}, nil | |||||
| } | |||||
| // forwardList stores a mapping between remote | |||||
| // forward requests and the tcpListeners. | |||||
| type forwardList struct { | |||||
| sync.Mutex | |||||
| entries []forwardEntry | |||||
| } | |||||
| // forwardEntry represents an established mapping of a laddr on a | |||||
| // remote ssh server to a channel connected to a tcpListener. | |||||
| type forwardEntry struct { | |||||
| laddr net.TCPAddr | |||||
| c chan forward | |||||
| } | |||||
| // forward represents an incoming forwarded tcpip connection. The | |||||
| // arguments to add/remove/lookup should be address as specified in | |||||
| // the original forward-request. | |||||
| type forward struct { | |||||
| newCh NewChannel // the ssh client channel underlying this forward | |||||
| raddr *net.TCPAddr // the raddr of the incoming connection | |||||
| } | |||||
| func (l *forwardList) add(addr net.TCPAddr) chan forward { | |||||
| l.Lock() | |||||
| defer l.Unlock() | |||||
| f := forwardEntry{ | |||||
| addr, | |||||
| make(chan forward, 1), | |||||
| } | |||||
| l.entries = append(l.entries, f) | |||||
| return f.c | |||||
| } | |||||
| // See RFC 4254, section 7.2 | |||||
| type forwardedTCPPayload struct { | |||||
| Addr string | |||||
| Port uint32 | |||||
| OriginAddr string | |||||
| OriginPort uint32 | |||||
| } | |||||
| // parseTCPAddr parses the originating address from the remote into a *net.TCPAddr. | |||||
| func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { | |||||
| if port == 0 || port > 65535 { | |||||
| return nil, fmt.Errorf("ssh: port number out of range: %d", port) | |||||
| } | |||||
| ip := net.ParseIP(string(addr)) | |||||
| if ip == nil { | |||||
| return nil, fmt.Errorf("ssh: cannot parse IP address %q", addr) | |||||
| } | |||||
| return &net.TCPAddr{IP: ip, Port: int(port)}, nil | |||||
| } | |||||
| func (l *forwardList) handleChannels(in <-chan NewChannel) { | |||||
| for ch := range in { | |||||
| var payload forwardedTCPPayload | |||||
| if err := Unmarshal(ch.ExtraData(), &payload); err != nil { | |||||
| ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) | |||||
| continue | |||||
| } | |||||
| // RFC 4254 section 7.2 specifies that incoming | |||||
| // addresses should list the address, in string | |||||
| // format. It is implied that this should be an IP | |||||
| // address, as it would be impossible to connect to it | |||||
| // otherwise. | |||||
| laddr, err := parseTCPAddr(payload.Addr, payload.Port) | |||||
| if err != nil { | |||||
| ch.Reject(ConnectionFailed, err.Error()) | |||||
| continue | |||||
| } | |||||
| raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) | |||||
| if err != nil { | |||||
| ch.Reject(ConnectionFailed, err.Error()) | |||||
| continue | |||||
| } | |||||
| if ok := l.forward(*laddr, *raddr, ch); !ok { | |||||
| // Section 7.2, implementations MUST reject spurious incoming | |||||
| // connections. | |||||
| ch.Reject(Prohibited, "no forward for address") | |||||
| continue | |||||
| } | |||||
| } | |||||
| } | |||||
| // remove removes the forward entry, and the channel feeding its | |||||
| // listener. | |||||
| func (l *forwardList) remove(addr net.TCPAddr) { | |||||
| l.Lock() | |||||
| defer l.Unlock() | |||||
| for i, f := range l.entries { | |||||
| if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { | |||||
| l.entries = append(l.entries[:i], l.entries[i+1:]...) | |||||
| close(f.c) | |||||
| return | |||||
| } | |||||
| } | |||||
| } | |||||
| // closeAll closes and clears all forwards. | |||||
| func (l *forwardList) closeAll() { | |||||
| l.Lock() | |||||
| defer l.Unlock() | |||||
| for _, f := range l.entries { | |||||
| close(f.c) | |||||
| } | |||||
| l.entries = nil | |||||
| } | |||||
| func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { | |||||
| l.Lock() | |||||
| defer l.Unlock() | |||||
| for _, f := range l.entries { | |||||
| if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { | |||||
| f.c <- forward{ch, &raddr} | |||||
| return true | |||||
| } | |||||
| } | |||||
| return false | |||||
| } | |||||
| type tcpListener struct { | |||||
| laddr *net.TCPAddr | |||||
| conn *Client | |||||
| in <-chan forward | |||||
| } | |||||
| // Accept waits for and returns the next connection to the listener. | |||||
| func (l *tcpListener) Accept() (net.Conn, error) { | |||||
| s, ok := <-l.in | |||||
| if !ok { | |||||
| return nil, io.EOF | |||||
| } | |||||
| ch, incoming, err := s.newCh.Accept() | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| go DiscardRequests(incoming) | |||||
| return &tcpChanConn{ | |||||
| Channel: ch, | |||||
| laddr: l.laddr, | |||||
| raddr: s.raddr, | |||||
| }, nil | |||||
| } | |||||
| // Close closes the listener. | |||||
| func (l *tcpListener) Close() error { | |||||
| m := channelForwardMsg{ | |||||
| l.laddr.IP.String(), | |||||
| uint32(l.laddr.Port), | |||||
| } | |||||
| // this also closes the listener. | |||||
| l.conn.forwards.remove(*l.laddr) | |||||
| ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) | |||||
| if err == nil && !ok { | |||||
| err = errors.New("ssh: cancel-tcpip-forward failed") | |||||
| } | |||||
| return err | |||||
| } | |||||
| // Addr returns the listener's network address. | |||||
| func (l *tcpListener) Addr() net.Addr { | |||||
| return l.laddr | |||||
| } | |||||
| // Dial initiates a connection to the addr from the remote host. | |||||
| // The resulting connection has a zero LocalAddr() and RemoteAddr(). | |||||
| func (c *Client) Dial(n, addr string) (net.Conn, error) { | |||||
| // Parse the address into host and numeric port. | |||||
| host, portString, err := net.SplitHostPort(addr) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| port, err := strconv.ParseUint(portString, 10, 16) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| // Use a zero address for local and remote address. | |||||
| zeroAddr := &net.TCPAddr{ | |||||
| IP: net.IPv4zero, | |||||
| Port: 0, | |||||
| } | |||||
| ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &tcpChanConn{ | |||||
| Channel: ch, | |||||
| laddr: zeroAddr, | |||||
| raddr: zeroAddr, | |||||
| }, nil | |||||
| } | |||||
| // DialTCP connects to the remote address raddr on the network net, | |||||
| // which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is used | |||||
| // as the local address for the connection. | |||||
| func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) { | |||||
| if laddr == nil { | |||||
| laddr = &net.TCPAddr{ | |||||
| IP: net.IPv4zero, | |||||
| Port: 0, | |||||
| } | |||||
| } | |||||
| ch, err := c.dial(laddr.IP.String(), laddr.Port, raddr.IP.String(), raddr.Port) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return &tcpChanConn{ | |||||
| Channel: ch, | |||||
| laddr: laddr, | |||||
| raddr: raddr, | |||||
| }, nil | |||||
| } | |||||
| // RFC 4254 7.2 | |||||
| type channelOpenDirectMsg struct { | |||||
| raddr string | |||||
| rport uint32 | |||||
| laddr string | |||||
| lport uint32 | |||||
| } | |||||
| func (c *Client) dial(laddr string, lport int, raddr string, rport int) (Channel, error) { | |||||
| msg := channelOpenDirectMsg{ | |||||
| raddr: raddr, | |||||
| rport: uint32(rport), | |||||
| laddr: laddr, | |||||
| lport: uint32(lport), | |||||
| } | |||||
| ch, in, err := c.OpenChannel("direct-tcpip", Marshal(&msg)) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| go DiscardRequests(in) | |||||
| return ch, err | |||||
| } | |||||
| type tcpChan struct { | |||||
| Channel // the backing channel | |||||
| } | |||||
| // tcpChanConn fulfills the net.Conn interface without | |||||
| // the tcpChan having to hold laddr or raddr directly. | |||||
| type tcpChanConn struct { | |||||
| Channel | |||||
| laddr, raddr net.Addr | |||||
| } | |||||
| // LocalAddr returns the local network address. | |||||
| func (t *tcpChanConn) LocalAddr() net.Addr { | |||||
| return t.laddr | |||||
| } | |||||
| // RemoteAddr returns the remote network address. | |||||
| func (t *tcpChanConn) RemoteAddr() net.Addr { | |||||
| return t.raddr | |||||
| } | |||||
| // SetDeadline sets the read and write deadlines associated | |||||
| // with the connection. | |||||
| func (t *tcpChanConn) SetDeadline(deadline time.Time) error { | |||||
| if err := t.SetReadDeadline(deadline); err != nil { | |||||
| return err | |||||
| } | |||||
| return t.SetWriteDeadline(deadline) | |||||
| } | |||||
| // SetReadDeadline sets the read deadline. | |||||
| // A zero value for t means Read will not time out. | |||||
| // After the deadline, the error from Read will implement net.Error | |||||
| // with Timeout() == true. | |||||
| func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { | |||||
| return errors.New("ssh: tcpChan: deadline not supported") | |||||
| } | |||||
| // SetWriteDeadline exists to satisfy the net.Conn interface | |||||
| // but is not implemented by this type. It always returns an error. | |||||
| func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { | |||||
| return errors.New("ssh: tcpChan: deadline not supported") | |||||
| } | |||||
| @@ -1,20 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "testing" | |||||
| ) | |||||
| func TestAutoPortListenBroken(t *testing.T) { | |||||
| broken := "SSH-2.0-OpenSSH_5.9hh11" | |||||
| works := "SSH-2.0-OpenSSH_6.1" | |||||
| if !isBrokenOpenSSHVersion(broken) { | |||||
| t.Errorf("version %q not marked as broken", broken) | |||||
| } | |||||
| if isBrokenOpenSSHVersion(works) { | |||||
| t.Errorf("version %q marked as broken", works) | |||||
| } | |||||
| } | |||||
| @@ -1,892 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package terminal | |||||
| import ( | |||||
| "bytes" | |||||
| "io" | |||||
| "sync" | |||||
| "unicode/utf8" | |||||
| ) | |||||
| // EscapeCodes contains escape sequences that can be written to the terminal in | |||||
| // order to achieve different styles of text. | |||||
| type EscapeCodes struct { | |||||
| // Foreground colors | |||||
| Black, Red, Green, Yellow, Blue, Magenta, Cyan, White []byte | |||||
| // Reset all attributes | |||||
| Reset []byte | |||||
| } | |||||
| var vt100EscapeCodes = EscapeCodes{ | |||||
| Black: []byte{keyEscape, '[', '3', '0', 'm'}, | |||||
| Red: []byte{keyEscape, '[', '3', '1', 'm'}, | |||||
| Green: []byte{keyEscape, '[', '3', '2', 'm'}, | |||||
| Yellow: []byte{keyEscape, '[', '3', '3', 'm'}, | |||||
| Blue: []byte{keyEscape, '[', '3', '4', 'm'}, | |||||
| Magenta: []byte{keyEscape, '[', '3', '5', 'm'}, | |||||
| Cyan: []byte{keyEscape, '[', '3', '6', 'm'}, | |||||
| White: []byte{keyEscape, '[', '3', '7', 'm'}, | |||||
| Reset: []byte{keyEscape, '[', '0', 'm'}, | |||||
| } | |||||
| // Terminal contains the state for running a VT100 terminal that is capable of | |||||
| // reading lines of input. | |||||
| type Terminal struct { | |||||
| // AutoCompleteCallback, if non-null, is called for each keypress with | |||||
| // the full input line and the current position of the cursor (in | |||||
| // bytes, as an index into |line|). If it returns ok=false, the key | |||||
| // press is processed normally. Otherwise it returns a replacement line | |||||
| // and the new cursor position. | |||||
| AutoCompleteCallback func(line string, pos int, key rune) (newLine string, newPos int, ok bool) | |||||
| // Escape contains a pointer to the escape codes for this terminal. | |||||
| // It's always a valid pointer, although the escape codes themselves | |||||
| // may be empty if the terminal doesn't support them. | |||||
| Escape *EscapeCodes | |||||
| // lock protects the terminal and the state in this object from | |||||
| // concurrent processing of a key press and a Write() call. | |||||
| lock sync.Mutex | |||||
| c io.ReadWriter | |||||
| prompt []rune | |||||
| // line is the current line being entered. | |||||
| line []rune | |||||
| // pos is the logical position of the cursor in line | |||||
| pos int | |||||
| // echo is true if local echo is enabled | |||||
| echo bool | |||||
| // pasteActive is true iff there is a bracketed paste operation in | |||||
| // progress. | |||||
| pasteActive bool | |||||
| // cursorX contains the current X value of the cursor where the left | |||||
| // edge is 0. cursorY contains the row number where the first row of | |||||
| // the current line is 0. | |||||
| cursorX, cursorY int | |||||
| // maxLine is the greatest value of cursorY so far. | |||||
| maxLine int | |||||
| termWidth, termHeight int | |||||
| // outBuf contains the terminal data to be sent. | |||||
| outBuf []byte | |||||
| // remainder contains the remainder of any partial key sequences after | |||||
| // a read. It aliases into inBuf. | |||||
| remainder []byte | |||||
| inBuf [256]byte | |||||
| // history contains previously entered commands so that they can be | |||||
| // accessed with the up and down keys. | |||||
| history stRingBuffer | |||||
| // historyIndex stores the currently accessed history entry, where zero | |||||
| // means the immediately previous entry. | |||||
| historyIndex int | |||||
| // When navigating up and down the history it's possible to return to | |||||
| // the incomplete, initial line. That value is stored in | |||||
| // historyPending. | |||||
| historyPending string | |||||
| } | |||||
| // NewTerminal runs a VT100 terminal on the given ReadWriter. If the ReadWriter is | |||||
| // a local terminal, that terminal must first have been put into raw mode. | |||||
| // prompt is a string that is written at the start of each input line (i.e. | |||||
| // "> "). | |||||
| func NewTerminal(c io.ReadWriter, prompt string) *Terminal { | |||||
| return &Terminal{ | |||||
| Escape: &vt100EscapeCodes, | |||||
| c: c, | |||||
| prompt: []rune(prompt), | |||||
| termWidth: 80, | |||||
| termHeight: 24, | |||||
| echo: true, | |||||
| historyIndex: -1, | |||||
| } | |||||
| } | |||||
| const ( | |||||
| keyCtrlD = 4 | |||||
| keyCtrlU = 21 | |||||
| keyEnter = '\r' | |||||
| keyEscape = 27 | |||||
| keyBackspace = 127 | |||||
| keyUnknown = 0xd800 /* UTF-16 surrogate area */ + iota | |||||
| keyUp | |||||
| keyDown | |||||
| keyLeft | |||||
| keyRight | |||||
| keyAltLeft | |||||
| keyAltRight | |||||
| keyHome | |||||
| keyEnd | |||||
| keyDeleteWord | |||||
| keyDeleteLine | |||||
| keyClearScreen | |||||
| keyPasteStart | |||||
| keyPasteEnd | |||||
| ) | |||||
| var pasteStart = []byte{keyEscape, '[', '2', '0', '0', '~'} | |||||
| var pasteEnd = []byte{keyEscape, '[', '2', '0', '1', '~'} | |||||
| // bytesToKey tries to parse a key sequence from b. If successful, it returns | |||||
| // the key and the remainder of the input. Otherwise it returns utf8.RuneError. | |||||
| func bytesToKey(b []byte, pasteActive bool) (rune, []byte) { | |||||
| if len(b) == 0 { | |||||
| return utf8.RuneError, nil | |||||
| } | |||||
| if !pasteActive { | |||||
| switch b[0] { | |||||
| case 1: // ^A | |||||
| return keyHome, b[1:] | |||||
| case 5: // ^E | |||||
| return keyEnd, b[1:] | |||||
| case 8: // ^H | |||||
| return keyBackspace, b[1:] | |||||
| case 11: // ^K | |||||
| return keyDeleteLine, b[1:] | |||||
| case 12: // ^L | |||||
| return keyClearScreen, b[1:] | |||||
| case 23: // ^W | |||||
| return keyDeleteWord, b[1:] | |||||
| } | |||||
| } | |||||
| if b[0] != keyEscape { | |||||
| if !utf8.FullRune(b) { | |||||
| return utf8.RuneError, b | |||||
| } | |||||
| r, l := utf8.DecodeRune(b) | |||||
| return r, b[l:] | |||||
| } | |||||
| if !pasteActive && len(b) >= 3 && b[0] == keyEscape && b[1] == '[' { | |||||
| switch b[2] { | |||||
| case 'A': | |||||
| return keyUp, b[3:] | |||||
| case 'B': | |||||
| return keyDown, b[3:] | |||||
| case 'C': | |||||
| return keyRight, b[3:] | |||||
| case 'D': | |||||
| return keyLeft, b[3:] | |||||
| case 'H': | |||||
| return keyHome, b[3:] | |||||
| case 'F': | |||||
| return keyEnd, b[3:] | |||||
| } | |||||
| } | |||||
| if !pasteActive && len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' { | |||||
| switch b[5] { | |||||
| case 'C': | |||||
| return keyAltRight, b[6:] | |||||
| case 'D': | |||||
| return keyAltLeft, b[6:] | |||||
| } | |||||
| } | |||||
| if !pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteStart) { | |||||
| return keyPasteStart, b[6:] | |||||
| } | |||||
| if pasteActive && len(b) >= 6 && bytes.Equal(b[:6], pasteEnd) { | |||||
| return keyPasteEnd, b[6:] | |||||
| } | |||||
| // If we get here then we have a key that we don't recognise, or a | |||||
| // partial sequence. It's not clear how one should find the end of a | |||||
| // sequence without knowing them all, but it seems that [a-zA-Z~] only | |||||
| // appears at the end of a sequence. | |||||
| for i, c := range b[0:] { | |||||
| if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c == '~' { | |||||
| return keyUnknown, b[i+1:] | |||||
| } | |||||
| } | |||||
| return utf8.RuneError, b | |||||
| } | |||||
| // queue appends data to the end of t.outBuf | |||||
| func (t *Terminal) queue(data []rune) { | |||||
| t.outBuf = append(t.outBuf, []byte(string(data))...) | |||||
| } | |||||
| var eraseUnderCursor = []rune{' ', keyEscape, '[', 'D'} | |||||
| var space = []rune{' '} | |||||
| func isPrintable(key rune) bool { | |||||
| isInSurrogateArea := key >= 0xd800 && key <= 0xdbff | |||||
| return key >= 32 && !isInSurrogateArea | |||||
| } | |||||
| // moveCursorToPos appends data to t.outBuf which will move the cursor to the | |||||
| // given, logical position in the text. | |||||
| func (t *Terminal) moveCursorToPos(pos int) { | |||||
| if !t.echo { | |||||
| return | |||||
| } | |||||
| x := visualLength(t.prompt) + pos | |||||
| y := x / t.termWidth | |||||
| x = x % t.termWidth | |||||
| up := 0 | |||||
| if y < t.cursorY { | |||||
| up = t.cursorY - y | |||||
| } | |||||
| down := 0 | |||||
| if y > t.cursorY { | |||||
| down = y - t.cursorY | |||||
| } | |||||
| left := 0 | |||||
| if x < t.cursorX { | |||||
| left = t.cursorX - x | |||||
| } | |||||
| right := 0 | |||||
| if x > t.cursorX { | |||||
| right = x - t.cursorX | |||||
| } | |||||
| t.cursorX = x | |||||
| t.cursorY = y | |||||
| t.move(up, down, left, right) | |||||
| } | |||||
| func (t *Terminal) move(up, down, left, right int) { | |||||
| movement := make([]rune, 3*(up+down+left+right)) | |||||
| m := movement | |||||
| for i := 0; i < up; i++ { | |||||
| m[0] = keyEscape | |||||
| m[1] = '[' | |||||
| m[2] = 'A' | |||||
| m = m[3:] | |||||
| } | |||||
| for i := 0; i < down; i++ { | |||||
| m[0] = keyEscape | |||||
| m[1] = '[' | |||||
| m[2] = 'B' | |||||
| m = m[3:] | |||||
| } | |||||
| for i := 0; i < left; i++ { | |||||
| m[0] = keyEscape | |||||
| m[1] = '[' | |||||
| m[2] = 'D' | |||||
| m = m[3:] | |||||
| } | |||||
| for i := 0; i < right; i++ { | |||||
| m[0] = keyEscape | |||||
| m[1] = '[' | |||||
| m[2] = 'C' | |||||
| m = m[3:] | |||||
| } | |||||
| t.queue(movement) | |||||
| } | |||||
| func (t *Terminal) clearLineToRight() { | |||||
| op := []rune{keyEscape, '[', 'K'} | |||||
| t.queue(op) | |||||
| } | |||||
| const maxLineLength = 4096 | |||||
| func (t *Terminal) setLine(newLine []rune, newPos int) { | |||||
| if t.echo { | |||||
| t.moveCursorToPos(0) | |||||
| t.writeLine(newLine) | |||||
| for i := len(newLine); i < len(t.line); i++ { | |||||
| t.writeLine(space) | |||||
| } | |||||
| t.moveCursorToPos(newPos) | |||||
| } | |||||
| t.line = newLine | |||||
| t.pos = newPos | |||||
| } | |||||
| func (t *Terminal) advanceCursor(places int) { | |||||
| t.cursorX += places | |||||
| t.cursorY += t.cursorX / t.termWidth | |||||
| if t.cursorY > t.maxLine { | |||||
| t.maxLine = t.cursorY | |||||
| } | |||||
| t.cursorX = t.cursorX % t.termWidth | |||||
| if places > 0 && t.cursorX == 0 { | |||||
| // Normally terminals will advance the current position | |||||
| // when writing a character. But that doesn't happen | |||||
| // for the last character in a line. However, when | |||||
| // writing a character (except a new line) that causes | |||||
| // a line wrap, the position will be advanced two | |||||
| // places. | |||||
| // | |||||
| // So, if we are stopping at the end of a line, we | |||||
| // need to write a newline so that our cursor can be | |||||
| // advanced to the next line. | |||||
| t.outBuf = append(t.outBuf, '\n') | |||||
| } | |||||
| } | |||||
| func (t *Terminal) eraseNPreviousChars(n int) { | |||||
| if n == 0 { | |||||
| return | |||||
| } | |||||
| if t.pos < n { | |||||
| n = t.pos | |||||
| } | |||||
| t.pos -= n | |||||
| t.moveCursorToPos(t.pos) | |||||
| copy(t.line[t.pos:], t.line[n+t.pos:]) | |||||
| t.line = t.line[:len(t.line)-n] | |||||
| if t.echo { | |||||
| t.writeLine(t.line[t.pos:]) | |||||
| for i := 0; i < n; i++ { | |||||
| t.queue(space) | |||||
| } | |||||
| t.advanceCursor(n) | |||||
| t.moveCursorToPos(t.pos) | |||||
| } | |||||
| } | |||||
| // countToLeftWord returns then number of characters from the cursor to the | |||||
| // start of the previous word. | |||||
| func (t *Terminal) countToLeftWord() int { | |||||
| if t.pos == 0 { | |||||
| return 0 | |||||
| } | |||||
| pos := t.pos - 1 | |||||
| for pos > 0 { | |||||
| if t.line[pos] != ' ' { | |||||
| break | |||||
| } | |||||
| pos-- | |||||
| } | |||||
| for pos > 0 { | |||||
| if t.line[pos] == ' ' { | |||||
| pos++ | |||||
| break | |||||
| } | |||||
| pos-- | |||||
| } | |||||
| return t.pos - pos | |||||
| } | |||||
| // countToRightWord returns then number of characters from the cursor to the | |||||
| // start of the next word. | |||||
| func (t *Terminal) countToRightWord() int { | |||||
| pos := t.pos | |||||
| for pos < len(t.line) { | |||||
| if t.line[pos] == ' ' { | |||||
| break | |||||
| } | |||||
| pos++ | |||||
| } | |||||
| for pos < len(t.line) { | |||||
| if t.line[pos] != ' ' { | |||||
| break | |||||
| } | |||||
| pos++ | |||||
| } | |||||
| return pos - t.pos | |||||
| } | |||||
| // visualLength returns the number of visible glyphs in s. | |||||
| func visualLength(runes []rune) int { | |||||
| inEscapeSeq := false | |||||
| length := 0 | |||||
| for _, r := range runes { | |||||
| switch { | |||||
| case inEscapeSeq: | |||||
| if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { | |||||
| inEscapeSeq = false | |||||
| } | |||||
| case r == '\x1b': | |||||
| inEscapeSeq = true | |||||
| default: | |||||
| length++ | |||||
| } | |||||
| } | |||||
| return length | |||||
| } | |||||
| // handleKey processes the given key and, optionally, returns a line of text | |||||
| // that the user has entered. | |||||
| func (t *Terminal) handleKey(key rune) (line string, ok bool) { | |||||
| if t.pasteActive && key != keyEnter { | |||||
| t.addKeyToLine(key) | |||||
| return | |||||
| } | |||||
| switch key { | |||||
| case keyBackspace: | |||||
| if t.pos == 0 { | |||||
| return | |||||
| } | |||||
| t.eraseNPreviousChars(1) | |||||
| case keyAltLeft: | |||||
| // move left by a word. | |||||
| t.pos -= t.countToLeftWord() | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyAltRight: | |||||
| // move right by a word. | |||||
| t.pos += t.countToRightWord() | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyLeft: | |||||
| if t.pos == 0 { | |||||
| return | |||||
| } | |||||
| t.pos-- | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyRight: | |||||
| if t.pos == len(t.line) { | |||||
| return | |||||
| } | |||||
| t.pos++ | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyHome: | |||||
| if t.pos == 0 { | |||||
| return | |||||
| } | |||||
| t.pos = 0 | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyEnd: | |||||
| if t.pos == len(t.line) { | |||||
| return | |||||
| } | |||||
| t.pos = len(t.line) | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyUp: | |||||
| entry, ok := t.history.NthPreviousEntry(t.historyIndex + 1) | |||||
| if !ok { | |||||
| return "", false | |||||
| } | |||||
| if t.historyIndex == -1 { | |||||
| t.historyPending = string(t.line) | |||||
| } | |||||
| t.historyIndex++ | |||||
| runes := []rune(entry) | |||||
| t.setLine(runes, len(runes)) | |||||
| case keyDown: | |||||
| switch t.historyIndex { | |||||
| case -1: | |||||
| return | |||||
| case 0: | |||||
| runes := []rune(t.historyPending) | |||||
| t.setLine(runes, len(runes)) | |||||
| t.historyIndex-- | |||||
| default: | |||||
| entry, ok := t.history.NthPreviousEntry(t.historyIndex - 1) | |||||
| if ok { | |||||
| t.historyIndex-- | |||||
| runes := []rune(entry) | |||||
| t.setLine(runes, len(runes)) | |||||
| } | |||||
| } | |||||
| case keyEnter: | |||||
| t.moveCursorToPos(len(t.line)) | |||||
| t.queue([]rune("\r\n")) | |||||
| line = string(t.line) | |||||
| ok = true | |||||
| t.line = t.line[:0] | |||||
| t.pos = 0 | |||||
| t.cursorX = 0 | |||||
| t.cursorY = 0 | |||||
| t.maxLine = 0 | |||||
| case keyDeleteWord: | |||||
| // Delete zero or more spaces and then one or more characters. | |||||
| t.eraseNPreviousChars(t.countToLeftWord()) | |||||
| case keyDeleteLine: | |||||
| // Delete everything from the current cursor position to the | |||||
| // end of line. | |||||
| for i := t.pos; i < len(t.line); i++ { | |||||
| t.queue(space) | |||||
| t.advanceCursor(1) | |||||
| } | |||||
| t.line = t.line[:t.pos] | |||||
| t.moveCursorToPos(t.pos) | |||||
| case keyCtrlD: | |||||
| // Erase the character under the current position. | |||||
| // The EOF case when the line is empty is handled in | |||||
| // readLine(). | |||||
| if t.pos < len(t.line) { | |||||
| t.pos++ | |||||
| t.eraseNPreviousChars(1) | |||||
| } | |||||
| case keyCtrlU: | |||||
| t.eraseNPreviousChars(t.pos) | |||||
| case keyClearScreen: | |||||
| // Erases the screen and moves the cursor to the home position. | |||||
| t.queue([]rune("\x1b[2J\x1b[H")) | |||||
| t.queue(t.prompt) | |||||
| t.cursorX, t.cursorY = 0, 0 | |||||
| t.advanceCursor(visualLength(t.prompt)) | |||||
| t.setLine(t.line, t.pos) | |||||
| default: | |||||
| if t.AutoCompleteCallback != nil { | |||||
| prefix := string(t.line[:t.pos]) | |||||
| suffix := string(t.line[t.pos:]) | |||||
| t.lock.Unlock() | |||||
| newLine, newPos, completeOk := t.AutoCompleteCallback(prefix+suffix, len(prefix), key) | |||||
| t.lock.Lock() | |||||
| if completeOk { | |||||
| t.setLine([]rune(newLine), utf8.RuneCount([]byte(newLine)[:newPos])) | |||||
| return | |||||
| } | |||||
| } | |||||
| if !isPrintable(key) { | |||||
| return | |||||
| } | |||||
| if len(t.line) == maxLineLength { | |||||
| return | |||||
| } | |||||
| t.addKeyToLine(key) | |||||
| } | |||||
| return | |||||
| } | |||||
| // addKeyToLine inserts the given key at the current position in the current | |||||
| // line. | |||||
| func (t *Terminal) addKeyToLine(key rune) { | |||||
| if len(t.line) == cap(t.line) { | |||||
| newLine := make([]rune, len(t.line), 2*(1+len(t.line))) | |||||
| copy(newLine, t.line) | |||||
| t.line = newLine | |||||
| } | |||||
| t.line = t.line[:len(t.line)+1] | |||||
| copy(t.line[t.pos+1:], t.line[t.pos:]) | |||||
| t.line[t.pos] = key | |||||
| if t.echo { | |||||
| t.writeLine(t.line[t.pos:]) | |||||
| } | |||||
| t.pos++ | |||||
| t.moveCursorToPos(t.pos) | |||||
| } | |||||
| func (t *Terminal) writeLine(line []rune) { | |||||
| for len(line) != 0 { | |||||
| remainingOnLine := t.termWidth - t.cursorX | |||||
| todo := len(line) | |||||
| if todo > remainingOnLine { | |||||
| todo = remainingOnLine | |||||
| } | |||||
| t.queue(line[:todo]) | |||||
| t.advanceCursor(visualLength(line[:todo])) | |||||
| line = line[todo:] | |||||
| } | |||||
| } | |||||
| func (t *Terminal) Write(buf []byte) (n int, err error) { | |||||
| t.lock.Lock() | |||||
| defer t.lock.Unlock() | |||||
| if t.cursorX == 0 && t.cursorY == 0 { | |||||
| // This is the easy case: there's nothing on the screen that we | |||||
| // have to move out of the way. | |||||
| return t.c.Write(buf) | |||||
| } | |||||
| // We have a prompt and possibly user input on the screen. We | |||||
| // have to clear it first. | |||||
| t.move(0 /* up */, 0 /* down */, t.cursorX /* left */, 0 /* right */) | |||||
| t.cursorX = 0 | |||||
| t.clearLineToRight() | |||||
| for t.cursorY > 0 { | |||||
| t.move(1 /* up */, 0, 0, 0) | |||||
| t.cursorY-- | |||||
| t.clearLineToRight() | |||||
| } | |||||
| if _, err = t.c.Write(t.outBuf); err != nil { | |||||
| return | |||||
| } | |||||
| t.outBuf = t.outBuf[:0] | |||||
| if n, err = t.c.Write(buf); err != nil { | |||||
| return | |||||
| } | |||||
| t.writeLine(t.prompt) | |||||
| if t.echo { | |||||
| t.writeLine(t.line) | |||||
| } | |||||
| t.moveCursorToPos(t.pos) | |||||
| if _, err = t.c.Write(t.outBuf); err != nil { | |||||
| return | |||||
| } | |||||
| t.outBuf = t.outBuf[:0] | |||||
| return | |||||
| } | |||||
| // ReadPassword temporarily changes the prompt and reads a password, without | |||||
| // echo, from the terminal. | |||||
| func (t *Terminal) ReadPassword(prompt string) (line string, err error) { | |||||
| t.lock.Lock() | |||||
| defer t.lock.Unlock() | |||||
| oldPrompt := t.prompt | |||||
| t.prompt = []rune(prompt) | |||||
| t.echo = false | |||||
| line, err = t.readLine() | |||||
| t.prompt = oldPrompt | |||||
| t.echo = true | |||||
| return | |||||
| } | |||||
| // ReadLine returns a line of input from the terminal. | |||||
| func (t *Terminal) ReadLine() (line string, err error) { | |||||
| t.lock.Lock() | |||||
| defer t.lock.Unlock() | |||||
| return t.readLine() | |||||
| } | |||||
| func (t *Terminal) readLine() (line string, err error) { | |||||
| // t.lock must be held at this point | |||||
| if t.cursorX == 0 && t.cursorY == 0 { | |||||
| t.writeLine(t.prompt) | |||||
| t.c.Write(t.outBuf) | |||||
| t.outBuf = t.outBuf[:0] | |||||
| } | |||||
| lineIsPasted := t.pasteActive | |||||
| for { | |||||
| rest := t.remainder | |||||
| lineOk := false | |||||
| for !lineOk { | |||||
| var key rune | |||||
| key, rest = bytesToKey(rest, t.pasteActive) | |||||
| if key == utf8.RuneError { | |||||
| break | |||||
| } | |||||
| if !t.pasteActive { | |||||
| if key == keyCtrlD { | |||||
| if len(t.line) == 0 { | |||||
| return "", io.EOF | |||||
| } | |||||
| } | |||||
| if key == keyPasteStart { | |||||
| t.pasteActive = true | |||||
| if len(t.line) == 0 { | |||||
| lineIsPasted = true | |||||
| } | |||||
| continue | |||||
| } | |||||
| } else if key == keyPasteEnd { | |||||
| t.pasteActive = false | |||||
| continue | |||||
| } | |||||
| if !t.pasteActive { | |||||
| lineIsPasted = false | |||||
| } | |||||
| line, lineOk = t.handleKey(key) | |||||
| } | |||||
| if len(rest) > 0 { | |||||
| n := copy(t.inBuf[:], rest) | |||||
| t.remainder = t.inBuf[:n] | |||||
| } else { | |||||
| t.remainder = nil | |||||
| } | |||||
| t.c.Write(t.outBuf) | |||||
| t.outBuf = t.outBuf[:0] | |||||
| if lineOk { | |||||
| if t.echo { | |||||
| t.historyIndex = -1 | |||||
| t.history.Add(line) | |||||
| } | |||||
| if lineIsPasted { | |||||
| err = ErrPasteIndicator | |||||
| } | |||||
| return | |||||
| } | |||||
| // t.remainder is a slice at the beginning of t.inBuf | |||||
| // containing a partial key sequence | |||||
| readBuf := t.inBuf[len(t.remainder):] | |||||
| var n int | |||||
| t.lock.Unlock() | |||||
| n, err = t.c.Read(readBuf) | |||||
| t.lock.Lock() | |||||
| if err != nil { | |||||
| return | |||||
| } | |||||
| t.remainder = t.inBuf[:n+len(t.remainder)] | |||||
| } | |||||
| panic("unreachable") // for Go 1.0. | |||||
| } | |||||
| // SetPrompt sets the prompt to be used when reading subsequent lines. | |||||
| func (t *Terminal) SetPrompt(prompt string) { | |||||
| t.lock.Lock() | |||||
| defer t.lock.Unlock() | |||||
| t.prompt = []rune(prompt) | |||||
| } | |||||
| func (t *Terminal) clearAndRepaintLinePlusNPrevious(numPrevLines int) { | |||||
| // Move cursor to column zero at the start of the line. | |||||
| t.move(t.cursorY, 0, t.cursorX, 0) | |||||
| t.cursorX, t.cursorY = 0, 0 | |||||
| t.clearLineToRight() | |||||
| for t.cursorY < numPrevLines { | |||||
| // Move down a line | |||||
| t.move(0, 1, 0, 0) | |||||
| t.cursorY++ | |||||
| t.clearLineToRight() | |||||
| } | |||||
| // Move back to beginning. | |||||
| t.move(t.cursorY, 0, 0, 0) | |||||
| t.cursorX, t.cursorY = 0, 0 | |||||
| t.queue(t.prompt) | |||||
| t.advanceCursor(visualLength(t.prompt)) | |||||
| t.writeLine(t.line) | |||||
| t.moveCursorToPos(t.pos) | |||||
| } | |||||
| func (t *Terminal) SetSize(width, height int) error { | |||||
| t.lock.Lock() | |||||
| defer t.lock.Unlock() | |||||
| if width == 0 { | |||||
| width = 1 | |||||
| } | |||||
| oldWidth := t.termWidth | |||||
| t.termWidth, t.termHeight = width, height | |||||
| switch { | |||||
| case width == oldWidth: | |||||
| // If the width didn't change then nothing else needs to be | |||||
| // done. | |||||
| return nil | |||||
| case len(t.line) == 0 && t.cursorX == 0 && t.cursorY == 0: | |||||
| // If there is nothing on current line and no prompt printed, | |||||
| // just do nothing | |||||
| return nil | |||||
| case width < oldWidth: | |||||
| // Some terminals (e.g. xterm) will truncate lines that were | |||||
| // too long when shinking. Others, (e.g. gnome-terminal) will | |||||
| // attempt to wrap them. For the former, repainting t.maxLine | |||||
| // works great, but that behaviour goes badly wrong in the case | |||||
| // of the latter because they have doubled every full line. | |||||
| // We assume that we are working on a terminal that wraps lines | |||||
| // and adjust the cursor position based on every previous line | |||||
| // wrapping and turning into two. This causes the prompt on | |||||
| // xterms to move upwards, which isn't great, but it avoids a | |||||
| // huge mess with gnome-terminal. | |||||
| if t.cursorX >= t.termWidth { | |||||
| t.cursorX = t.termWidth - 1 | |||||
| } | |||||
| t.cursorY *= 2 | |||||
| t.clearAndRepaintLinePlusNPrevious(t.maxLine * 2) | |||||
| case width > oldWidth: | |||||
| // If the terminal expands then our position calculations will | |||||
| // be wrong in the future because we think the cursor is | |||||
| // |t.pos| chars into the string, but there will be a gap at | |||||
| // the end of any wrapped line. | |||||
| // | |||||
| // But the position will actually be correct until we move, so | |||||
| // we can move back to the beginning and repaint everything. | |||||
| t.clearAndRepaintLinePlusNPrevious(t.maxLine) | |||||
| } | |||||
| _, err := t.c.Write(t.outBuf) | |||||
| t.outBuf = t.outBuf[:0] | |||||
| return err | |||||
| } | |||||
| type pasteIndicatorError struct{} | |||||
| func (pasteIndicatorError) Error() string { | |||||
| return "terminal: ErrPasteIndicator not correctly handled" | |||||
| } | |||||
| // ErrPasteIndicator may be returned from ReadLine as the error, in addition | |||||
| // to valid line data. It indicates that bracketed paste mode is enabled and | |||||
| // that the returned line consists only of pasted data. Programs may wish to | |||||
| // interpret pasted data more literally than typed data. | |||||
| var ErrPasteIndicator = pasteIndicatorError{} | |||||
| // SetBracketedPasteMode requests that the terminal bracket paste operations | |||||
| // with markers. Not all terminals support this but, if it is supported, then | |||||
| // enabling this mode will stop any autocomplete callback from running due to | |||||
| // pastes. Additionally, any lines that are completely pasted will be returned | |||||
| // from ReadLine with the error set to ErrPasteIndicator. | |||||
| func (t *Terminal) SetBracketedPasteMode(on bool) { | |||||
| if on { | |||||
| io.WriteString(t.c, "\x1b[?2004h") | |||||
| } else { | |||||
| io.WriteString(t.c, "\x1b[?2004l") | |||||
| } | |||||
| } | |||||
| // stRingBuffer is a ring buffer of strings. | |||||
| type stRingBuffer struct { | |||||
| // entries contains max elements. | |||||
| entries []string | |||||
| max int | |||||
| // head contains the index of the element most recently added to the ring. | |||||
| head int | |||||
| // size contains the number of elements in the ring. | |||||
| size int | |||||
| } | |||||
| func (s *stRingBuffer) Add(a string) { | |||||
| if s.entries == nil { | |||||
| const defaultNumEntries = 100 | |||||
| s.entries = make([]string, defaultNumEntries) | |||||
| s.max = defaultNumEntries | |||||
| } | |||||
| s.head = (s.head + 1) % s.max | |||||
| s.entries[s.head] = a | |||||
| if s.size < s.max { | |||||
| s.size++ | |||||
| } | |||||
| } | |||||
| // NthPreviousEntry returns the value passed to the nth previous call to Add. | |||||
| // If n is zero then the immediately prior value is returned, if one, then the | |||||
| // next most recent, and so on. If such an element doesn't exist then ok is | |||||
| // false. | |||||
| func (s *stRingBuffer) NthPreviousEntry(n int) (value string, ok bool) { | |||||
| if n >= s.size { | |||||
| return "", false | |||||
| } | |||||
| index := s.head - n | |||||
| if index < 0 { | |||||
| index += s.max | |||||
| } | |||||
| return s.entries[index], true | |||||
| } | |||||
| @@ -1,269 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package terminal | |||||
| import ( | |||||
| "io" | |||||
| "testing" | |||||
| ) | |||||
| type MockTerminal struct { | |||||
| toSend []byte | |||||
| bytesPerRead int | |||||
| received []byte | |||||
| } | |||||
| func (c *MockTerminal) Read(data []byte) (n int, err error) { | |||||
| n = len(data) | |||||
| if n == 0 { | |||||
| return | |||||
| } | |||||
| if n > len(c.toSend) { | |||||
| n = len(c.toSend) | |||||
| } | |||||
| if n == 0 { | |||||
| return 0, io.EOF | |||||
| } | |||||
| if c.bytesPerRead > 0 && n > c.bytesPerRead { | |||||
| n = c.bytesPerRead | |||||
| } | |||||
| copy(data, c.toSend[:n]) | |||||
| c.toSend = c.toSend[n:] | |||||
| return | |||||
| } | |||||
| func (c *MockTerminal) Write(data []byte) (n int, err error) { | |||||
| c.received = append(c.received, data...) | |||||
| return len(data), nil | |||||
| } | |||||
| func TestClose(t *testing.T) { | |||||
| c := &MockTerminal{} | |||||
| ss := NewTerminal(c, "> ") | |||||
| line, err := ss.ReadLine() | |||||
| if line != "" { | |||||
| t.Errorf("Expected empty line but got: %s", line) | |||||
| } | |||||
| if err != io.EOF { | |||||
| t.Errorf("Error should have been EOF but got: %s", err) | |||||
| } | |||||
| } | |||||
| var keyPressTests = []struct { | |||||
| in string | |||||
| line string | |||||
| err error | |||||
| throwAwayLines int | |||||
| }{ | |||||
| { | |||||
| err: io.EOF, | |||||
| }, | |||||
| { | |||||
| in: "\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| in: "foo\r", | |||||
| line: "foo", | |||||
| }, | |||||
| { | |||||
| in: "a\x1b[Cb\r", // right | |||||
| line: "ab", | |||||
| }, | |||||
| { | |||||
| in: "a\x1b[Db\r", // left | |||||
| line: "ba", | |||||
| }, | |||||
| { | |||||
| in: "a\177b\r", // backspace | |||||
| line: "b", | |||||
| }, | |||||
| { | |||||
| in: "\x1b[A\r", // up | |||||
| }, | |||||
| { | |||||
| in: "\x1b[B\r", // down | |||||
| }, | |||||
| { | |||||
| in: "line\x1b[A\x1b[B\r", // up then down | |||||
| line: "line", | |||||
| }, | |||||
| { | |||||
| in: "line1\rline2\x1b[A\r", // recall previous line. | |||||
| line: "line1", | |||||
| throwAwayLines: 1, | |||||
| }, | |||||
| { | |||||
| // recall two previous lines and append. | |||||
| in: "line1\rline2\rline3\x1b[A\x1b[Axxx\r", | |||||
| line: "line1xxx", | |||||
| throwAwayLines: 2, | |||||
| }, | |||||
| { | |||||
| // Ctrl-A to move to beginning of line followed by ^K to kill | |||||
| // line. | |||||
| in: "a b \001\013\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| // Ctrl-A to move to beginning of line, Ctrl-E to move to end, | |||||
| // finally ^K to kill nothing. | |||||
| in: "a b \001\005\013\r", | |||||
| line: "a b ", | |||||
| }, | |||||
| { | |||||
| in: "\027\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| in: "a\027\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| in: "a \027\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| in: "a b\027\r", | |||||
| line: "a ", | |||||
| }, | |||||
| { | |||||
| in: "a b \027\r", | |||||
| line: "a ", | |||||
| }, | |||||
| { | |||||
| in: "one two thr\x1b[D\027\r", | |||||
| line: "one two r", | |||||
| }, | |||||
| { | |||||
| in: "\013\r", | |||||
| line: "", | |||||
| }, | |||||
| { | |||||
| in: "a\013\r", | |||||
| line: "a", | |||||
| }, | |||||
| { | |||||
| in: "ab\x1b[D\013\r", | |||||
| line: "a", | |||||
| }, | |||||
| { | |||||
| in: "Ξεσκεπάζω\r", | |||||
| line: "Ξεσκεπάζω", | |||||
| }, | |||||
| { | |||||
| in: "£\r\x1b[A\177\r", // non-ASCII char, enter, up, backspace. | |||||
| line: "", | |||||
| throwAwayLines: 1, | |||||
| }, | |||||
| { | |||||
| in: "£\r££\x1b[A\x1b[B\177\r", // non-ASCII char, enter, 2x non-ASCII, up, down, backspace, enter. | |||||
| line: "£", | |||||
| throwAwayLines: 1, | |||||
| }, | |||||
| { | |||||
| // Ctrl-D at the end of the line should be ignored. | |||||
| in: "a\004\r", | |||||
| line: "a", | |||||
| }, | |||||
| { | |||||
| // a, b, left, Ctrl-D should erase the b. | |||||
| in: "ab\x1b[D\004\r", | |||||
| line: "a", | |||||
| }, | |||||
| { | |||||
| // a, b, c, d, left, left, ^U should erase to the beginning of | |||||
| // the line. | |||||
| in: "abcd\x1b[D\x1b[D\025\r", | |||||
| line: "cd", | |||||
| }, | |||||
| { | |||||
| // Bracketed paste mode: control sequences should be returned | |||||
| // verbatim in paste mode. | |||||
| in: "abc\x1b[200~de\177f\x1b[201~\177\r", | |||||
| line: "abcde\177", | |||||
| }, | |||||
| { | |||||
| // Enter in bracketed paste mode should still work. | |||||
| in: "abc\x1b[200~d\refg\x1b[201~h\r", | |||||
| line: "efgh", | |||||
| throwAwayLines: 1, | |||||
| }, | |||||
| { | |||||
| // Lines consisting entirely of pasted data should be indicated as such. | |||||
| in: "\x1b[200~a\r", | |||||
| line: "a", | |||||
| err: ErrPasteIndicator, | |||||
| }, | |||||
| } | |||||
| func TestKeyPresses(t *testing.T) { | |||||
| for i, test := range keyPressTests { | |||||
| for j := 1; j < len(test.in); j++ { | |||||
| c := &MockTerminal{ | |||||
| toSend: []byte(test.in), | |||||
| bytesPerRead: j, | |||||
| } | |||||
| ss := NewTerminal(c, "> ") | |||||
| for k := 0; k < test.throwAwayLines; k++ { | |||||
| _, err := ss.ReadLine() | |||||
| if err != nil { | |||||
| t.Errorf("Throwaway line %d from test %d resulted in error: %s", k, i, err) | |||||
| } | |||||
| } | |||||
| line, err := ss.ReadLine() | |||||
| if line != test.line { | |||||
| t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line) | |||||
| break | |||||
| } | |||||
| if err != test.err { | |||||
| t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err) | |||||
| break | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestPasswordNotSaved(t *testing.T) { | |||||
| c := &MockTerminal{ | |||||
| toSend: []byte("password\r\x1b[A\r"), | |||||
| bytesPerRead: 1, | |||||
| } | |||||
| ss := NewTerminal(c, "> ") | |||||
| pw, _ := ss.ReadPassword("> ") | |||||
| if pw != "password" { | |||||
| t.Fatalf("failed to read password, got %s", pw) | |||||
| } | |||||
| line, _ := ss.ReadLine() | |||||
| if len(line) > 0 { | |||||
| t.Fatalf("password was saved in history") | |||||
| } | |||||
| } | |||||
| var setSizeTests = []struct { | |||||
| width, height int | |||||
| }{ | |||||
| {40, 13}, | |||||
| {80, 24}, | |||||
| {132, 43}, | |||||
| } | |||||
| func TestTerminalSetSize(t *testing.T) { | |||||
| for _, setSize := range setSizeTests { | |||||
| c := &MockTerminal{ | |||||
| toSend: []byte("password\r\x1b[A\r"), | |||||
| bytesPerRead: 1, | |||||
| } | |||||
| ss := NewTerminal(c, "> ") | |||||
| ss.SetSize(setSize.width, setSize.height) | |||||
| pw, _ := ss.ReadPassword("Password: ") | |||||
| if pw != "password" { | |||||
| t.Fatalf("failed to read password, got %s", pw) | |||||
| } | |||||
| if string(c.received) != "Password: \r\n" { | |||||
| t.Errorf("failed to set the temporary prompt expected %q, got %q", "Password: ", c.received) | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,128 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd linux,!appengine netbsd openbsd | |||||
| // Package terminal provides support functions for dealing with terminals, as | |||||
| // commonly found on UNIX systems. | |||||
| // | |||||
| // Putting a terminal into raw mode is the most common requirement: | |||||
| // | |||||
| // oldState, err := terminal.MakeRaw(0) | |||||
| // if err != nil { | |||||
| // panic(err) | |||||
| // } | |||||
| // defer terminal.Restore(0, oldState) | |||||
| package terminal | |||||
| import ( | |||||
| "io" | |||||
| "syscall" | |||||
| "unsafe" | |||||
| ) | |||||
| // State contains the state of a terminal. | |||||
| type State struct { | |||||
| termios syscall.Termios | |||||
| } | |||||
| // IsTerminal returns true if the given file descriptor is a terminal. | |||||
| func IsTerminal(fd int) bool { | |||||
| var termios syscall.Termios | |||||
| _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&termios)), 0, 0, 0) | |||||
| return err == 0 | |||||
| } | |||||
| // MakeRaw put the terminal connected to the given file descriptor into raw | |||||
| // mode and returns the previous state of the terminal so that it can be | |||||
| // restored. | |||||
| func MakeRaw(fd int) (*State, error) { | |||||
| var oldState State | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { | |||||
| return nil, err | |||||
| } | |||||
| newState := oldState.termios | |||||
| newState.Iflag &^= syscall.ISTRIP | syscall.INLCR | syscall.ICRNL | syscall.IGNCR | syscall.IXON | syscall.IXOFF | |||||
| newState.Lflag &^= syscall.ECHO | syscall.ICANON | syscall.ISIG | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { | |||||
| return nil, err | |||||
| } | |||||
| return &oldState, nil | |||||
| } | |||||
| // GetState returns the current state of a terminal which may be useful to | |||||
| // restore the terminal after a signal. | |||||
| func GetState(fd int) (*State, error) { | |||||
| var oldState State | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState.termios)), 0, 0, 0); err != 0 { | |||||
| return nil, err | |||||
| } | |||||
| return &oldState, nil | |||||
| } | |||||
| // Restore restores the terminal connected to the given file descriptor to a | |||||
| // previous state. | |||||
| func Restore(fd int, state *State) error { | |||||
| _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&state.termios)), 0, 0, 0) | |||||
| return err | |||||
| } | |||||
| // GetSize returns the dimensions of the given terminal. | |||||
| func GetSize(fd int) (width, height int, err error) { | |||||
| var dimensions [4]uint16 | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), uintptr(syscall.TIOCGWINSZ), uintptr(unsafe.Pointer(&dimensions)), 0, 0, 0); err != 0 { | |||||
| return -1, -1, err | |||||
| } | |||||
| return int(dimensions[1]), int(dimensions[0]), nil | |||||
| } | |||||
| // ReadPassword reads a line of input from a terminal without local echo. This | |||||
| // is commonly used for inputting passwords and other sensitive data. The slice | |||||
| // returned does not include the \n. | |||||
| func ReadPassword(fd int) ([]byte, error) { | |||||
| var oldState syscall.Termios | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlReadTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0); err != 0 { | |||||
| return nil, err | |||||
| } | |||||
| newState := oldState | |||||
| newState.Lflag &^= syscall.ECHO | |||||
| newState.Lflag |= syscall.ICANON | syscall.ISIG | |||||
| newState.Iflag |= syscall.ICRNL | |||||
| if _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&newState)), 0, 0, 0); err != 0 { | |||||
| return nil, err | |||||
| } | |||||
| defer func() { | |||||
| syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), ioctlWriteTermios, uintptr(unsafe.Pointer(&oldState)), 0, 0, 0) | |||||
| }() | |||||
| var buf [16]byte | |||||
| var ret []byte | |||||
| for { | |||||
| n, err := syscall.Read(fd, buf[:]) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if n == 0 { | |||||
| if len(ret) == 0 { | |||||
| return nil, io.EOF | |||||
| } | |||||
| break | |||||
| } | |||||
| if buf[n-1] == '\n' { | |||||
| n-- | |||||
| } | |||||
| ret = append(ret, buf[:n]...) | |||||
| if n < len(buf) { | |||||
| break | |||||
| } | |||||
| } | |||||
| return ret, nil | |||||
| } | |||||
| @@ -1,12 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd netbsd openbsd | |||||
| package terminal | |||||
| import "syscall" | |||||
| const ioctlReadTermios = syscall.TIOCGETA | |||||
| const ioctlWriteTermios = syscall.TIOCSETA | |||||
| @@ -1,11 +0,0 @@ | |||||
| // Copyright 2013 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package terminal | |||||
| // These constants are declared here, rather than importing | |||||
| // them from the syscall package as some syscall packages, even | |||||
| // on linux, for example gccgo, do not declare them. | |||||
| const ioctlReadTermios = 0x5401 // syscall.TCGETS | |||||
| const ioctlWriteTermios = 0x5402 // syscall.TCSETS | |||||
| @@ -1,174 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build windows | |||||
| // Package terminal provides support functions for dealing with terminals, as | |||||
| // commonly found on UNIX systems. | |||||
| // | |||||
| // Putting a terminal into raw mode is the most common requirement: | |||||
| // | |||||
| // oldState, err := terminal.MakeRaw(0) | |||||
| // if err != nil { | |||||
| // panic(err) | |||||
| // } | |||||
| // defer terminal.Restore(0, oldState) | |||||
| package terminal | |||||
| import ( | |||||
| "io" | |||||
| "syscall" | |||||
| "unsafe" | |||||
| ) | |||||
| const ( | |||||
| enableLineInput = 2 | |||||
| enableEchoInput = 4 | |||||
| enableProcessedInput = 1 | |||||
| enableWindowInput = 8 | |||||
| enableMouseInput = 16 | |||||
| enableInsertMode = 32 | |||||
| enableQuickEditMode = 64 | |||||
| enableExtendedFlags = 128 | |||||
| enableAutoPosition = 256 | |||||
| enableProcessedOutput = 1 | |||||
| enableWrapAtEolOutput = 2 | |||||
| ) | |||||
| var kernel32 = syscall.NewLazyDLL("kernel32.dll") | |||||
| var ( | |||||
| procGetConsoleMode = kernel32.NewProc("GetConsoleMode") | |||||
| procSetConsoleMode = kernel32.NewProc("SetConsoleMode") | |||||
| procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") | |||||
| ) | |||||
| type ( | |||||
| short int16 | |||||
| word uint16 | |||||
| coord struct { | |||||
| x short | |||||
| y short | |||||
| } | |||||
| smallRect struct { | |||||
| left short | |||||
| top short | |||||
| right short | |||||
| bottom short | |||||
| } | |||||
| consoleScreenBufferInfo struct { | |||||
| size coord | |||||
| cursorPosition coord | |||||
| attributes word | |||||
| window smallRect | |||||
| maximumWindowSize coord | |||||
| } | |||||
| ) | |||||
| type State struct { | |||||
| mode uint32 | |||||
| } | |||||
| // IsTerminal returns true if the given file descriptor is a terminal. | |||||
| func IsTerminal(fd int) bool { | |||||
| var st uint32 | |||||
| r, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) | |||||
| return r != 0 && e == 0 | |||||
| } | |||||
| // MakeRaw put the terminal connected to the given file descriptor into raw | |||||
| // mode and returns the previous state of the terminal so that it can be | |||||
| // restored. | |||||
| func MakeRaw(fd int) (*State, error) { | |||||
| var st uint32 | |||||
| _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) | |||||
| if e != 0 { | |||||
| return nil, error(e) | |||||
| } | |||||
| st &^= (enableEchoInput | enableProcessedInput | enableLineInput | enableProcessedOutput) | |||||
| _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) | |||||
| if e != 0 { | |||||
| return nil, error(e) | |||||
| } | |||||
| return &State{st}, nil | |||||
| } | |||||
| // GetState returns the current state of a terminal which may be useful to | |||||
| // restore the terminal after a signal. | |||||
| func GetState(fd int) (*State, error) { | |||||
| var st uint32 | |||||
| _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) | |||||
| if e != 0 { | |||||
| return nil, error(e) | |||||
| } | |||||
| return &State{st}, nil | |||||
| } | |||||
| // Restore restores the terminal connected to the given file descriptor to a | |||||
| // previous state. | |||||
| func Restore(fd int, state *State) error { | |||||
| _, _, err := syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(state.mode), 0) | |||||
| return err | |||||
| } | |||||
| // GetSize returns the dimensions of the given terminal. | |||||
| func GetSize(fd int) (width, height int, err error) { | |||||
| var info consoleScreenBufferInfo | |||||
| _, _, e := syscall.Syscall(procGetConsoleScreenBufferInfo.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&info)), 0) | |||||
| if e != 0 { | |||||
| return 0, 0, error(e) | |||||
| } | |||||
| return int(info.size.x), int(info.size.y), nil | |||||
| } | |||||
| // ReadPassword reads a line of input from a terminal without local echo. This | |||||
| // is commonly used for inputting passwords and other sensitive data. The slice | |||||
| // returned does not include the \n. | |||||
| func ReadPassword(fd int) ([]byte, error) { | |||||
| var st uint32 | |||||
| _, _, e := syscall.Syscall(procGetConsoleMode.Addr(), 2, uintptr(fd), uintptr(unsafe.Pointer(&st)), 0) | |||||
| if e != 0 { | |||||
| return nil, error(e) | |||||
| } | |||||
| old := st | |||||
| st &^= (enableEchoInput) | |||||
| st |= (enableProcessedInput | enableLineInput | enableProcessedOutput) | |||||
| _, _, e = syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(st), 0) | |||||
| if e != 0 { | |||||
| return nil, error(e) | |||||
| } | |||||
| defer func() { | |||||
| syscall.Syscall(procSetConsoleMode.Addr(), 2, uintptr(fd), uintptr(old), 0) | |||||
| }() | |||||
| var buf [16]byte | |||||
| var ret []byte | |||||
| for { | |||||
| n, err := syscall.Read(syscall.Handle(fd), buf[:]) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| if n == 0 { | |||||
| if len(ret) == 0 { | |||||
| return nil, io.EOF | |||||
| } | |||||
| break | |||||
| } | |||||
| if buf[n-1] == '\n' { | |||||
| n-- | |||||
| } | |||||
| if n > 0 && buf[n-1] == '\r' { | |||||
| n-- | |||||
| } | |||||
| ret = append(ret, buf[:n]...) | |||||
| if n < len(buf) { | |||||
| break | |||||
| } | |||||
| } | |||||
| return ret, nil | |||||
| } | |||||
| @@ -1,59 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd linux netbsd openbsd | |||||
| package test | |||||
| import ( | |||||
| "bytes" | |||||
| "testing" | |||||
| "golang.org/x/crypto/ssh" | |||||
| "golang.org/x/crypto/ssh/agent" | |||||
| ) | |||||
| func TestAgentForward(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| keyring := agent.NewKeyring() | |||||
| if err := keyring.Add(agent.AddedKey{PrivateKey: testPrivateKeys["dsa"]}); err != nil { | |||||
| t.Fatalf("Error adding key: %s", err) | |||||
| } | |||||
| if err := keyring.Add(agent.AddedKey{ | |||||
| PrivateKey: testPrivateKeys["dsa"], | |||||
| ConfirmBeforeUse: true, | |||||
| LifetimeSecs: 3600, | |||||
| }); err != nil { | |||||
| t.Fatalf("Error adding key with constraints: %s", err) | |||||
| } | |||||
| pub := testPublicKeys["dsa"] | |||||
| sess, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("NewSession: %v", err) | |||||
| } | |||||
| if err := agent.RequestAgentForwarding(sess); err != nil { | |||||
| t.Fatalf("RequestAgentForwarding: %v", err) | |||||
| } | |||||
| if err := agent.ForwardToAgent(conn, keyring); err != nil { | |||||
| t.Fatalf("SetupForwardKeyring: %v", err) | |||||
| } | |||||
| out, err := sess.CombinedOutput("ssh-add -L") | |||||
| if err != nil { | |||||
| t.Fatalf("running ssh-add: %v, out %s", err, out) | |||||
| } | |||||
| key, _, _, _, err := ssh.ParseAuthorizedKey(out) | |||||
| if err != nil { | |||||
| t.Fatalf("ParseAuthorizedKey(%q): %v", out, err) | |||||
| } | |||||
| if !bytes.Equal(key.Marshal(), pub.Marshal()) { | |||||
| t.Fatalf("got key %s, want %s", ssh.MarshalAuthorizedKey(key), ssh.MarshalAuthorizedKey(pub)) | |||||
| } | |||||
| } | |||||
| @@ -1,47 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd linux netbsd openbsd | |||||
| package test | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "testing" | |||||
| "golang.org/x/crypto/ssh" | |||||
| ) | |||||
| func TestCertLogin(t *testing.T) { | |||||
| s := newServer(t) | |||||
| defer s.Shutdown() | |||||
| // Use a key different from the default. | |||||
| clientKey := testSigners["dsa"] | |||||
| caAuthKey := testSigners["ecdsa"] | |||||
| cert := &ssh.Certificate{ | |||||
| Key: clientKey.PublicKey(), | |||||
| ValidPrincipals: []string{username()}, | |||||
| CertType: ssh.UserCert, | |||||
| ValidBefore: ssh.CertTimeInfinity, | |||||
| } | |||||
| if err := cert.SignCert(rand.Reader, caAuthKey); err != nil { | |||||
| t.Fatalf("SetSignature: %v", err) | |||||
| } | |||||
| certSigner, err := ssh.NewCertSigner(cert, clientKey) | |||||
| if err != nil { | |||||
| t.Fatalf("NewCertSigner: %v", err) | |||||
| } | |||||
| conf := &ssh.ClientConfig{ | |||||
| User: username(), | |||||
| } | |||||
| conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) | |||||
| client, err := s.TryDial(conf) | |||||
| if err != nil { | |||||
| t.Fatalf("TryDial: %v", err) | |||||
| } | |||||
| client.Close() | |||||
| } | |||||
| @@ -1,7 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // This package contains integration tests for the | |||||
| // golang.org/x/crypto/ssh package. | |||||
| package test | |||||
| @@ -1,160 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd linux netbsd openbsd | |||||
| package test | |||||
| import ( | |||||
| "bytes" | |||||
| "io" | |||||
| "io/ioutil" | |||||
| "math/rand" | |||||
| "net" | |||||
| "testing" | |||||
| "time" | |||||
| ) | |||||
| func TestPortForward(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| sshListener, err := conn.Listen("tcp", "localhost:0") | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| go func() { | |||||
| sshConn, err := sshListener.Accept() | |||||
| if err != nil { | |||||
| t.Fatalf("listen.Accept failed: %v", err) | |||||
| } | |||||
| _, err = io.Copy(sshConn, sshConn) | |||||
| if err != nil && err != io.EOF { | |||||
| t.Fatalf("ssh client copy: %v", err) | |||||
| } | |||||
| sshConn.Close() | |||||
| }() | |||||
| forwardedAddr := sshListener.Addr().String() | |||||
| tcpConn, err := net.Dial("tcp", forwardedAddr) | |||||
| if err != nil { | |||||
| t.Fatalf("TCP dial failed: %v", err) | |||||
| } | |||||
| readChan := make(chan []byte) | |||||
| go func() { | |||||
| data, _ := ioutil.ReadAll(tcpConn) | |||||
| readChan <- data | |||||
| }() | |||||
| // Invent some data. | |||||
| data := make([]byte, 100*1000) | |||||
| for i := range data { | |||||
| data[i] = byte(i % 255) | |||||
| } | |||||
| var sent []byte | |||||
| for len(sent) < 1000*1000 { | |||||
| // Send random sized chunks | |||||
| m := rand.Intn(len(data)) | |||||
| n, err := tcpConn.Write(data[:m]) | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| sent = append(sent, data[:n]...) | |||||
| } | |||||
| if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { | |||||
| t.Errorf("tcpConn.CloseWrite: %v", err) | |||||
| } | |||||
| read := <-readChan | |||||
| if len(sent) != len(read) { | |||||
| t.Fatalf("got %d bytes, want %d", len(read), len(sent)) | |||||
| } | |||||
| if bytes.Compare(sent, read) != 0 { | |||||
| t.Fatalf("read back data does not match") | |||||
| } | |||||
| if err := sshListener.Close(); err != nil { | |||||
| t.Fatalf("sshListener.Close: %v", err) | |||||
| } | |||||
| // Check that the forward disappeared. | |||||
| tcpConn, err = net.Dial("tcp", forwardedAddr) | |||||
| if err == nil { | |||||
| tcpConn.Close() | |||||
| t.Errorf("still listening to %s after closing", forwardedAddr) | |||||
| } | |||||
| } | |||||
| func TestAcceptClose(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| sshListener, err := conn.Listen("tcp", "localhost:0") | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| quit := make(chan error, 1) | |||||
| go func() { | |||||
| for { | |||||
| c, err := sshListener.Accept() | |||||
| if err != nil { | |||||
| quit <- err | |||||
| break | |||||
| } | |||||
| c.Close() | |||||
| } | |||||
| }() | |||||
| sshListener.Close() | |||||
| select { | |||||
| case <-time.After(1 * time.Second): | |||||
| t.Errorf("timeout: listener did not close.") | |||||
| case err := <-quit: | |||||
| t.Logf("quit as expected (error %v)", err) | |||||
| } | |||||
| } | |||||
| // Check that listeners exit if the underlying client transport dies. | |||||
| func TestPortForwardConnectionClose(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| sshListener, err := conn.Listen("tcp", "localhost:0") | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| quit := make(chan error, 1) | |||||
| go func() { | |||||
| for { | |||||
| c, err := sshListener.Accept() | |||||
| if err != nil { | |||||
| quit <- err | |||||
| break | |||||
| } | |||||
| c.Close() | |||||
| } | |||||
| }() | |||||
| // It would be even nicer if we closed the server side, but it | |||||
| // is more involved as the fd for that side is dup()ed. | |||||
| server.clientConn.Close() | |||||
| select { | |||||
| case <-time.After(1 * time.Second): | |||||
| t.Errorf("timeout: listener did not close.") | |||||
| case err := <-quit: | |||||
| t.Logf("quit as expected (error %v)", err) | |||||
| } | |||||
| } | |||||
| @@ -1,340 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build !windows | |||||
| package test | |||||
| // Session functional tests. | |||||
| import ( | |||||
| "bytes" | |||||
| "errors" | |||||
| "io" | |||||
| "strings" | |||||
| "testing" | |||||
| "golang.org/x/crypto/ssh" | |||||
| ) | |||||
| func TestRunCommandSuccess(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| err = session.Run("true") | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| } | |||||
| func TestHostKeyCheck(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conf := clientConfig() | |||||
| hostDB := hostKeyDB() | |||||
| conf.HostKeyCallback = hostDB.Check | |||||
| // change the keys. | |||||
| hostDB.keys[ssh.KeyAlgoRSA][25]++ | |||||
| hostDB.keys[ssh.KeyAlgoDSA][25]++ | |||||
| hostDB.keys[ssh.KeyAlgoECDSA256][25]++ | |||||
| conn, err := server.TryDial(conf) | |||||
| if err == nil { | |||||
| conn.Close() | |||||
| t.Fatalf("dial should have failed.") | |||||
| } else if !strings.Contains(err.Error(), "host key mismatch") { | |||||
| t.Fatalf("'host key mismatch' not found in %v", err) | |||||
| } | |||||
| } | |||||
| func TestRunCommandStdin(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| r, w := io.Pipe() | |||||
| defer r.Close() | |||||
| defer w.Close() | |||||
| session.Stdin = r | |||||
| err = session.Run("true") | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| } | |||||
| func TestRunCommandStdinError(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| r, w := io.Pipe() | |||||
| defer r.Close() | |||||
| session.Stdin = r | |||||
| pipeErr := errors.New("closing write end of pipe") | |||||
| w.CloseWithError(pipeErr) | |||||
| err = session.Run("true") | |||||
| if err != pipeErr { | |||||
| t.Fatalf("expected %v, found %v", pipeErr, err) | |||||
| } | |||||
| } | |||||
| func TestRunCommandFailed(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| err = session.Run(`bash -c "kill -9 $$"`) | |||||
| if err == nil { | |||||
| t.Fatalf("session succeeded: %v", err) | |||||
| } | |||||
| } | |||||
| func TestRunCommandWeClosed(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| err = session.Shell() | |||||
| if err != nil { | |||||
| t.Fatalf("shell failed: %v", err) | |||||
| } | |||||
| err = session.Close() | |||||
| if err != nil { | |||||
| t.Fatalf("shell failed: %v", err) | |||||
| } | |||||
| } | |||||
| func TestFuncLargeRead(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to create new session: %s", err) | |||||
| } | |||||
| stdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to acquire stdout pipe: %s", err) | |||||
| } | |||||
| err = session.Start("dd if=/dev/urandom bs=2048 count=1024") | |||||
| if err != nil { | |||||
| t.Fatalf("unable to execute remote command: %s", err) | |||||
| } | |||||
| buf := new(bytes.Buffer) | |||||
| n, err := io.Copy(buf, stdout) | |||||
| if err != nil { | |||||
| t.Fatalf("error reading from remote stdout: %s", err) | |||||
| } | |||||
| if n != 2048*1024 { | |||||
| t.Fatalf("Expected %d bytes but read only %d from remote command", 2048, n) | |||||
| } | |||||
| } | |||||
| func TestKeyChange(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conf := clientConfig() | |||||
| hostDB := hostKeyDB() | |||||
| conf.HostKeyCallback = hostDB.Check | |||||
| conf.RekeyThreshold = 1024 | |||||
| conn := server.Dial(conf) | |||||
| defer conn.Close() | |||||
| for i := 0; i < 4; i++ { | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to create new session: %s", err) | |||||
| } | |||||
| stdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to acquire stdout pipe: %s", err) | |||||
| } | |||||
| err = session.Start("dd if=/dev/urandom bs=1024 count=1") | |||||
| if err != nil { | |||||
| t.Fatalf("unable to execute remote command: %s", err) | |||||
| } | |||||
| buf := new(bytes.Buffer) | |||||
| n, err := io.Copy(buf, stdout) | |||||
| if err != nil { | |||||
| t.Fatalf("error reading from remote stdout: %s", err) | |||||
| } | |||||
| want := int64(1024) | |||||
| if n != want { | |||||
| t.Fatalf("Expected %d bytes but read only %d from remote command", want, n) | |||||
| } | |||||
| } | |||||
| if changes := hostDB.checkCount; changes < 4 { | |||||
| t.Errorf("got %d key changes, want 4", changes) | |||||
| } | |||||
| } | |||||
| func TestInvalidTerminalMode(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| if err = session.RequestPty("vt100", 80, 40, ssh.TerminalModes{255: 1984}); err == nil { | |||||
| t.Fatalf("req-pty failed: successful request with invalid mode") | |||||
| } | |||||
| } | |||||
| func TestValidTerminalMode(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conn := server.Dial(clientConfig()) | |||||
| defer conn.Close() | |||||
| session, err := conn.NewSession() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %v", err) | |||||
| } | |||||
| defer session.Close() | |||||
| stdout, err := session.StdoutPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to acquire stdout pipe: %s", err) | |||||
| } | |||||
| stdin, err := session.StdinPipe() | |||||
| if err != nil { | |||||
| t.Fatalf("unable to acquire stdin pipe: %s", err) | |||||
| } | |||||
| tm := ssh.TerminalModes{ssh.ECHO: 0} | |||||
| if err = session.RequestPty("xterm", 80, 40, tm); err != nil { | |||||
| t.Fatalf("req-pty failed: %s", err) | |||||
| } | |||||
| err = session.Shell() | |||||
| if err != nil { | |||||
| t.Fatalf("session failed: %s", err) | |||||
| } | |||||
| stdin.Write([]byte("stty -a && exit\n")) | |||||
| var buf bytes.Buffer | |||||
| if _, err := io.Copy(&buf, stdout); err != nil { | |||||
| t.Fatalf("reading failed: %s", err) | |||||
| } | |||||
| if sttyOutput := buf.String(); !strings.Contains(sttyOutput, "-echo ") { | |||||
| t.Fatalf("terminal mode failure: expected -echo in stty output, got %s", sttyOutput) | |||||
| } | |||||
| } | |||||
| func TestCiphers(t *testing.T) { | |||||
| var config ssh.Config | |||||
| config.SetDefaults() | |||||
| cipherOrder := config.Ciphers | |||||
| // This cipher will not be tested when commented out in cipher.go it will | |||||
| // fallback to the next available as per line 292. | |||||
| cipherOrder = append(cipherOrder, "aes128-cbc") | |||||
| for _, ciph := range cipherOrder { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conf := clientConfig() | |||||
| conf.Ciphers = []string{ciph} | |||||
| // Don't fail if sshd doesnt have the cipher. | |||||
| conf.Ciphers = append(conf.Ciphers, cipherOrder...) | |||||
| conn, err := server.TryDial(conf) | |||||
| if err == nil { | |||||
| conn.Close() | |||||
| } else { | |||||
| t.Fatalf("failed for cipher %q", ciph) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestMACs(t *testing.T) { | |||||
| var config ssh.Config | |||||
| config.SetDefaults() | |||||
| macOrder := config.MACs | |||||
| for _, mac := range macOrder { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conf := clientConfig() | |||||
| conf.MACs = []string{mac} | |||||
| // Don't fail if sshd doesnt have the MAC. | |||||
| conf.MACs = append(conf.MACs, macOrder...) | |||||
| if conn, err := server.TryDial(conf); err == nil { | |||||
| conn.Close() | |||||
| } else { | |||||
| t.Fatalf("failed for MAC %q", mac) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestKeyExchanges(t *testing.T) { | |||||
| var config ssh.Config | |||||
| config.SetDefaults() | |||||
| kexOrder := config.KeyExchanges | |||||
| for _, kex := range kexOrder { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| conf := clientConfig() | |||||
| // Don't fail if sshd doesnt have the kex. | |||||
| conf.KeyExchanges = append([]string{kex}, kexOrder...) | |||||
| conn, err := server.TryDial(conf) | |||||
| if err == nil { | |||||
| conn.Close() | |||||
| } else { | |||||
| t.Errorf("failed for kex %q", kex) | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,46 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build !windows | |||||
| package test | |||||
| // direct-tcpip functional tests | |||||
| import ( | |||||
| "io" | |||||
| "net" | |||||
| "testing" | |||||
| ) | |||||
| func TestDial(t *testing.T) { | |||||
| server := newServer(t) | |||||
| defer server.Shutdown() | |||||
| sshConn := server.Dial(clientConfig()) | |||||
| defer sshConn.Close() | |||||
| l, err := net.Listen("tcp", "127.0.0.1:0") | |||||
| if err != nil { | |||||
| t.Fatalf("Listen: %v", err) | |||||
| } | |||||
| defer l.Close() | |||||
| go func() { | |||||
| for { | |||||
| c, err := l.Accept() | |||||
| if err != nil { | |||||
| break | |||||
| } | |||||
| io.WriteString(c, c.RemoteAddr().String()) | |||||
| c.Close() | |||||
| } | |||||
| }() | |||||
| conn, err := sshConn.Dial("tcp", l.Addr().String()) | |||||
| if err != nil { | |||||
| t.Fatalf("Dial: %v", err) | |||||
| } | |||||
| defer conn.Close() | |||||
| } | |||||
| @@ -1,261 +0,0 @@ | |||||
| // Copyright 2012 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // +build darwin dragonfly freebsd linux netbsd openbsd plan9 | |||||
| package test | |||||
| // functional test harness for unix. | |||||
| import ( | |||||
| "bytes" | |||||
| "fmt" | |||||
| "io/ioutil" | |||||
| "log" | |||||
| "net" | |||||
| "os" | |||||
| "os/exec" | |||||
| "os/user" | |||||
| "path/filepath" | |||||
| "testing" | |||||
| "text/template" | |||||
| "golang.org/x/crypto/ssh" | |||||
| "golang.org/x/crypto/ssh/testdata" | |||||
| ) | |||||
| const sshd_config = ` | |||||
| Protocol 2 | |||||
| HostKey {{.Dir}}/id_rsa | |||||
| HostKey {{.Dir}}/id_dsa | |||||
| HostKey {{.Dir}}/id_ecdsa | |||||
| Pidfile {{.Dir}}/sshd.pid | |||||
| #UsePrivilegeSeparation no | |||||
| KeyRegenerationInterval 3600 | |||||
| ServerKeyBits 768 | |||||
| SyslogFacility AUTH | |||||
| LogLevel DEBUG2 | |||||
| LoginGraceTime 120 | |||||
| PermitRootLogin no | |||||
| StrictModes no | |||||
| RSAAuthentication yes | |||||
| PubkeyAuthentication yes | |||||
| AuthorizedKeysFile {{.Dir}}/id_user.pub | |||||
| TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub | |||||
| IgnoreRhosts yes | |||||
| RhostsRSAAuthentication no | |||||
| HostbasedAuthentication no | |||||
| ` | |||||
| var configTmpl = template.Must(template.New("").Parse(sshd_config)) | |||||
| type server struct { | |||||
| t *testing.T | |||||
| cleanup func() // executed during Shutdown | |||||
| configfile string | |||||
| cmd *exec.Cmd | |||||
| output bytes.Buffer // holds stderr from sshd process | |||||
| // Client half of the network connection. | |||||
| clientConn net.Conn | |||||
| } | |||||
| func username() string { | |||||
| var username string | |||||
| if user, err := user.Current(); err == nil { | |||||
| username = user.Username | |||||
| } else { | |||||
| // user.Current() currently requires cgo. If an error is | |||||
| // returned attempt to get the username from the environment. | |||||
| log.Printf("user.Current: %v; falling back on $USER", err) | |||||
| username = os.Getenv("USER") | |||||
| } | |||||
| if username == "" { | |||||
| panic("Unable to get username") | |||||
| } | |||||
| return username | |||||
| } | |||||
| type storedHostKey struct { | |||||
| // keys map from an algorithm string to binary key data. | |||||
| keys map[string][]byte | |||||
| // checkCount counts the Check calls. Used for testing | |||||
| // rekeying. | |||||
| checkCount int | |||||
| } | |||||
| func (k *storedHostKey) Add(key ssh.PublicKey) { | |||||
| if k.keys == nil { | |||||
| k.keys = map[string][]byte{} | |||||
| } | |||||
| k.keys[key.Type()] = key.Marshal() | |||||
| } | |||||
| func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error { | |||||
| k.checkCount++ | |||||
| algo := key.Type() | |||||
| if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 { | |||||
| return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo]) | |||||
| } | |||||
| return nil | |||||
| } | |||||
| func hostKeyDB() *storedHostKey { | |||||
| keyChecker := &storedHostKey{} | |||||
| keyChecker.Add(testPublicKeys["ecdsa"]) | |||||
| keyChecker.Add(testPublicKeys["rsa"]) | |||||
| keyChecker.Add(testPublicKeys["dsa"]) | |||||
| return keyChecker | |||||
| } | |||||
| func clientConfig() *ssh.ClientConfig { | |||||
| config := &ssh.ClientConfig{ | |||||
| User: username(), | |||||
| Auth: []ssh.AuthMethod{ | |||||
| ssh.PublicKeys(testSigners["user"]), | |||||
| }, | |||||
| HostKeyCallback: hostKeyDB().Check, | |||||
| } | |||||
| return config | |||||
| } | |||||
| // unixConnection creates two halves of a connected net.UnixConn. It | |||||
| // is used for connecting the Go SSH client with sshd without opening | |||||
| // ports. | |||||
| func unixConnection() (*net.UnixConn, *net.UnixConn, error) { | |||||
| dir, err := ioutil.TempDir("", "unixConnection") | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| defer os.Remove(dir) | |||||
| addr := filepath.Join(dir, "ssh") | |||||
| listener, err := net.Listen("unix", addr) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| defer listener.Close() | |||||
| c1, err := net.Dial("unix", addr) | |||||
| if err != nil { | |||||
| return nil, nil, err | |||||
| } | |||||
| c2, err := listener.Accept() | |||||
| if err != nil { | |||||
| c1.Close() | |||||
| return nil, nil, err | |||||
| } | |||||
| return c1.(*net.UnixConn), c2.(*net.UnixConn), nil | |||||
| } | |||||
| func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { | |||||
| sshd, err := exec.LookPath("sshd") | |||||
| if err != nil { | |||||
| s.t.Skipf("skipping test: %v", err) | |||||
| } | |||||
| c1, c2, err := unixConnection() | |||||
| if err != nil { | |||||
| s.t.Fatalf("unixConnection: %v", err) | |||||
| } | |||||
| s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e") | |||||
| f, err := c2.File() | |||||
| if err != nil { | |||||
| s.t.Fatalf("UnixConn.File: %v", err) | |||||
| } | |||||
| defer f.Close() | |||||
| s.cmd.Stdin = f | |||||
| s.cmd.Stdout = f | |||||
| s.cmd.Stderr = &s.output | |||||
| if err := s.cmd.Start(); err != nil { | |||||
| s.t.Fail() | |||||
| s.Shutdown() | |||||
| s.t.Fatalf("s.cmd.Start: %v", err) | |||||
| } | |||||
| s.clientConn = c1 | |||||
| conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return ssh.NewClient(conn, chans, reqs), nil | |||||
| } | |||||
| func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client { | |||||
| conn, err := s.TryDial(config) | |||||
| if err != nil { | |||||
| s.t.Fail() | |||||
| s.Shutdown() | |||||
| s.t.Fatalf("ssh.Client: %v", err) | |||||
| } | |||||
| return conn | |||||
| } | |||||
| func (s *server) Shutdown() { | |||||
| if s.cmd != nil && s.cmd.Process != nil { | |||||
| // Don't check for errors; if it fails it's most | |||||
| // likely "os: process already finished", and we don't | |||||
| // care about that. Use os.Interrupt, so child | |||||
| // processes are killed too. | |||||
| s.cmd.Process.Signal(os.Interrupt) | |||||
| s.cmd.Wait() | |||||
| } | |||||
| if s.t.Failed() { | |||||
| // log any output from sshd process | |||||
| s.t.Logf("sshd: %s", s.output.String()) | |||||
| } | |||||
| s.cleanup() | |||||
| } | |||||
| func writeFile(path string, contents []byte) { | |||||
| f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) | |||||
| if err != nil { | |||||
| panic(err) | |||||
| } | |||||
| defer f.Close() | |||||
| if _, err := f.Write(contents); err != nil { | |||||
| panic(err) | |||||
| } | |||||
| } | |||||
| // newServer returns a new mock ssh server. | |||||
| func newServer(t *testing.T) *server { | |||||
| if testing.Short() { | |||||
| t.Skip("skipping test due to -short") | |||||
| } | |||||
| dir, err := ioutil.TempDir("", "sshtest") | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| f, err := os.Create(filepath.Join(dir, "sshd_config")) | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| err = configTmpl.Execute(f, map[string]string{ | |||||
| "Dir": dir, | |||||
| }) | |||||
| if err != nil { | |||||
| t.Fatal(err) | |||||
| } | |||||
| f.Close() | |||||
| for k, v := range testdata.PEMBytes { | |||||
| filename := "id_" + k | |||||
| writeFile(filepath.Join(dir, filename), v) | |||||
| writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) | |||||
| } | |||||
| return &server{ | |||||
| t: t, | |||||
| configfile: f.Name(), | |||||
| cleanup: func() { | |||||
| if err := os.RemoveAll(dir); err != nil { | |||||
| t.Error(err) | |||||
| } | |||||
| }, | |||||
| } | |||||
| } | |||||
| @@ -1,64 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: | |||||
| // ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three | |||||
| // instances. | |||||
| package test | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "fmt" | |||||
| "golang.org/x/crypto/ssh" | |||||
| "golang.org/x/crypto/ssh/testdata" | |||||
| ) | |||||
| var ( | |||||
| testPrivateKeys map[string]interface{} | |||||
| testSigners map[string]ssh.Signer | |||||
| testPublicKeys map[string]ssh.PublicKey | |||||
| ) | |||||
| func init() { | |||||
| var err error | |||||
| n := len(testdata.PEMBytes) | |||||
| testPrivateKeys = make(map[string]interface{}, n) | |||||
| testSigners = make(map[string]ssh.Signer, n) | |||||
| testPublicKeys = make(map[string]ssh.PublicKey, n) | |||||
| for t, k := range testdata.PEMBytes { | |||||
| testPrivateKeys[t], err = ssh.ParseRawPrivateKey(k) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) | |||||
| } | |||||
| testSigners[t], err = ssh.NewSignerFromKey(testPrivateKeys[t]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) | |||||
| } | |||||
| testPublicKeys[t] = testSigners[t].PublicKey() | |||||
| } | |||||
| // Create a cert and sign it for use in tests. | |||||
| testCert := &ssh.Certificate{ | |||||
| Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage | |||||
| ValidAfter: 0, // unix epoch | |||||
| ValidBefore: ssh.CertTimeInfinity, // The end of currently representable time. | |||||
| Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| Key: testPublicKeys["ecdsa"], | |||||
| SignatureKey: testPublicKeys["rsa"], | |||||
| Permissions: ssh.Permissions{ | |||||
| CriticalOptions: map[string]string{}, | |||||
| Extensions: map[string]string{}, | |||||
| }, | |||||
| } | |||||
| testCert.SignCert(rand.Reader, testSigners["rsa"]) | |||||
| testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] | |||||
| testSigners["cert"], err = ssh.NewCertSigner(testCert, testSigners["ecdsa"]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) | |||||
| } | |||||
| } | |||||
| @@ -1,8 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // This package contains test data shared between the various subpackages of | |||||
| // the golang.org/x/crypto/ssh package. Under no circumstance should | |||||
| // this data be used for production code. | |||||
| package testdata | |||||
| @@ -1,43 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package testdata | |||||
| var PEMBytes = map[string][]byte{ | |||||
| "dsa": []byte(`-----BEGIN DSA PRIVATE KEY----- | |||||
| MIIBuwIBAAKBgQD6PDSEyXiI9jfNs97WuM46MSDCYlOqWw80ajN16AohtBncs1YB | |||||
| lHk//dQOvCYOsYaE+gNix2jtoRjwXhDsc25/IqQbU1ahb7mB8/rsaILRGIbA5WH3 | |||||
| EgFtJmXFovDz3if6F6TzvhFpHgJRmLYVR8cqsezL3hEZOvvs2iH7MorkxwIVAJHD | |||||
| nD82+lxh2fb4PMsIiaXudAsBAoGAQRf7Q/iaPRn43ZquUhd6WwvirqUj+tkIu6eV | |||||
| 2nZWYmXLlqFQKEy4Tejl7Wkyzr2OSYvbXLzo7TNxLKoWor6ips0phYPPMyXld14r | |||||
| juhT24CrhOzuLMhDduMDi032wDIZG4Y+K7ElU8Oufn8Sj5Wge8r6ANmmVgmFfynr | |||||
| FhdYCngCgYEA3ucGJ93/Mx4q4eKRDxcWD3QzWyqpbRVRRV1Vmih9Ha/qC994nJFz | |||||
| DQIdjxDIT2Rk2AGzMqFEB68Zc3O+Wcsmz5eWWzEwFxaTwOGWTyDqsDRLm3fD+QYj | |||||
| nOwuxb0Kce+gWI8voWcqC9cyRm09jGzu2Ab3Bhtpg8JJ8L7gS3MRZK4CFEx4UAfY | |||||
| Fmsr0W6fHB9nhS4/UXM8 | |||||
| -----END DSA PRIVATE KEY----- | |||||
| `), | |||||
| "ecdsa": []byte(`-----BEGIN EC PRIVATE KEY----- | |||||
| MHcCAQEEINGWx0zo6fhJ/0EAfrPzVFyFC9s18lBt3cRoEDhS3ARooAoGCCqGSM49 | |||||
| AwEHoUQDQgAEi9Hdw6KvZcWxfg2IDhA7UkpDtzzt6ZqJXSsFdLd+Kx4S3Sx4cVO+ | |||||
| 6/ZOXRnPmNAlLUqjShUsUBBngG0u2fqEqA== | |||||
| -----END EC PRIVATE KEY----- | |||||
| `), | |||||
| "rsa": []byte(`-----BEGIN RSA PRIVATE KEY----- | |||||
| MIIBOwIBAAJBALdGZxkXDAjsYk10ihwU6Id2KeILz1TAJuoq4tOgDWxEEGeTrcld | |||||
| r/ZwVaFzjWzxaf6zQIJbfaSEAhqD5yo72+sCAwEAAQJBAK8PEVU23Wj8mV0QjwcJ | |||||
| tZ4GcTUYQL7cF4+ezTCE9a1NrGnCP2RuQkHEKxuTVrxXt+6OF15/1/fuXnxKjmJC | |||||
| nxkCIQDaXvPPBi0c7vAxGwNY9726x01/dNbHCE0CBtcotobxpwIhANbbQbh3JHVW | |||||
| 2haQh4fAG5mhesZKAGcxTyv4mQ7uMSQdAiAj+4dzMpJWdSzQ+qGHlHMIBvVHLkqB | |||||
| y2VdEyF7DPCZewIhAI7GOI/6LDIFOvtPo6Bj2nNmyQ1HU6k/LRtNIXi4c9NJAiAr | |||||
| rrxx26itVhJmcvoUhOjwuzSlP2bE5VHAvkGB352YBg== | |||||
| -----END RSA PRIVATE KEY----- | |||||
| `), | |||||
| "user": []byte(`-----BEGIN EC PRIVATE KEY----- | |||||
| MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 | |||||
| AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD | |||||
| PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== | |||||
| -----END EC PRIVATE KEY----- | |||||
| `), | |||||
| } | |||||
| @@ -1,63 +0,0 @@ | |||||
| // Copyright 2014 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| // IMPLEMENTOR NOTE: To avoid a package loop, this file is in three places: | |||||
| // ssh/, ssh/agent, and ssh/test/. It should be kept in sync across all three | |||||
| // instances. | |||||
| package ssh | |||||
| import ( | |||||
| "crypto/rand" | |||||
| "fmt" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh/testdata" | |||||
| ) | |||||
| var ( | |||||
| testPrivateKeys map[string]interface{} | |||||
| testSigners map[string]Signer | |||||
| testPublicKeys map[string]PublicKey | |||||
| ) | |||||
| func init() { | |||||
| var err error | |||||
| n := len(testdata.PEMBytes) | |||||
| testPrivateKeys = make(map[string]interface{}, n) | |||||
| testSigners = make(map[string]Signer, n) | |||||
| testPublicKeys = make(map[string]PublicKey, n) | |||||
| for t, k := range testdata.PEMBytes { | |||||
| testPrivateKeys[t], err = ParseRawPrivateKey(k) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to parse test key %s: %v", t, err)) | |||||
| } | |||||
| testSigners[t], err = NewSignerFromKey(testPrivateKeys[t]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create signer for test key %s: %v", t, err)) | |||||
| } | |||||
| testPublicKeys[t] = testSigners[t].PublicKey() | |||||
| } | |||||
| // Create a cert and sign it for use in tests. | |||||
| testCert := &Certificate{ | |||||
| Nonce: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| ValidPrincipals: []string{"gopher1", "gopher2"}, // increases test coverage | |||||
| ValidAfter: 0, // unix epoch | |||||
| ValidBefore: CertTimeInfinity, // The end of currently representable time. | |||||
| Reserved: []byte{}, // To pass reflect.DeepEqual after marshal & parse, this must be non-nil | |||||
| Key: testPublicKeys["ecdsa"], | |||||
| SignatureKey: testPublicKeys["rsa"], | |||||
| Permissions: Permissions{ | |||||
| CriticalOptions: map[string]string{}, | |||||
| Extensions: map[string]string{}, | |||||
| }, | |||||
| } | |||||
| testCert.SignCert(rand.Reader, testSigners["rsa"]) | |||||
| testPrivateKeys["cert"] = testPrivateKeys["ecdsa"] | |||||
| testSigners["cert"], err = NewCertSigner(testCert, testSigners["ecdsa"]) | |||||
| if err != nil { | |||||
| panic(fmt.Sprintf("Unable to create certificate signer: %v", err)) | |||||
| } | |||||
| } | |||||
| @@ -1,332 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bufio" | |||||
| "errors" | |||||
| "io" | |||||
| ) | |||||
| const ( | |||||
| gcmCipherID = "aes128-gcm@openssh.com" | |||||
| aes128cbcID = "aes128-cbc" | |||||
| ) | |||||
| // packetConn represents a transport that implements packet based | |||||
| // operations. | |||||
| type packetConn interface { | |||||
| // Encrypt and send a packet of data to the remote peer. | |||||
| writePacket(packet []byte) error | |||||
| // Read a packet from the connection | |||||
| readPacket() ([]byte, error) | |||||
| // Close closes the write-side of the connection. | |||||
| Close() error | |||||
| } | |||||
| // transport is the keyingTransport that implements the SSH packet | |||||
| // protocol. | |||||
| type transport struct { | |||||
| reader connectionState | |||||
| writer connectionState | |||||
| bufReader *bufio.Reader | |||||
| bufWriter *bufio.Writer | |||||
| rand io.Reader | |||||
| io.Closer | |||||
| // Initial H used for the session ID. Once assigned this does | |||||
| // not change, even during subsequent key exchanges. | |||||
| sessionID []byte | |||||
| } | |||||
| // getSessionID returns the ID of the SSH connection. The return value | |||||
| // should not be modified. | |||||
| func (t *transport) getSessionID() []byte { | |||||
| if t.sessionID == nil { | |||||
| panic("session ID not set yet") | |||||
| } | |||||
| return t.sessionID | |||||
| } | |||||
| // packetCipher represents a combination of SSH encryption/MAC | |||||
| // protocol. A single instance should be used for one direction only. | |||||
| type packetCipher interface { | |||||
| // writePacket encrypts the packet and writes it to w. The | |||||
| // contents of the packet are generally scrambled. | |||||
| writePacket(seqnum uint32, w io.Writer, rand io.Reader, packet []byte) error | |||||
| // readPacket reads and decrypts a packet of data. The | |||||
| // returned packet may be overwritten by future calls of | |||||
| // readPacket. | |||||
| readPacket(seqnum uint32, r io.Reader) ([]byte, error) | |||||
| } | |||||
| // connectionState represents one side (read or write) of the | |||||
| // connection. This is necessary because each direction has its own | |||||
| // keys, and can even have its own algorithms | |||||
| type connectionState struct { | |||||
| packetCipher | |||||
| seqNum uint32 | |||||
| dir direction | |||||
| pendingKeyChange chan packetCipher | |||||
| } | |||||
| // prepareKeyChange sets up key material for a keychange. The key changes in | |||||
| // both directions are triggered by reading and writing a msgNewKey packet | |||||
| // respectively. | |||||
| func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error { | |||||
| if t.sessionID == nil { | |||||
| t.sessionID = kexResult.H | |||||
| } | |||||
| kexResult.SessionID = t.sessionID | |||||
| if ciph, err := newPacketCipher(t.reader.dir, algs.r, kexResult); err != nil { | |||||
| return err | |||||
| } else { | |||||
| t.reader.pendingKeyChange <- ciph | |||||
| } | |||||
| if ciph, err := newPacketCipher(t.writer.dir, algs.w, kexResult); err != nil { | |||||
| return err | |||||
| } else { | |||||
| t.writer.pendingKeyChange <- ciph | |||||
| } | |||||
| return nil | |||||
| } | |||||
| // Read and decrypt next packet. | |||||
| func (t *transport) readPacket() ([]byte, error) { | |||||
| return t.reader.readPacket(t.bufReader) | |||||
| } | |||||
| func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { | |||||
| packet, err := s.packetCipher.readPacket(s.seqNum, r) | |||||
| s.seqNum++ | |||||
| if err == nil && len(packet) == 0 { | |||||
| err = errors.New("ssh: zero length packet") | |||||
| } | |||||
| if len(packet) > 0 && packet[0] == msgNewKeys { | |||||
| select { | |||||
| case cipher := <-s.pendingKeyChange: | |||||
| s.packetCipher = cipher | |||||
| default: | |||||
| return nil, errors.New("ssh: got bogus newkeys message.") | |||||
| } | |||||
| } | |||||
| // The packet may point to an internal buffer, so copy the | |||||
| // packet out here. | |||||
| fresh := make([]byte, len(packet)) | |||||
| copy(fresh, packet) | |||||
| return fresh, err | |||||
| } | |||||
| func (t *transport) writePacket(packet []byte) error { | |||||
| return t.writer.writePacket(t.bufWriter, t.rand, packet) | |||||
| } | |||||
| func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { | |||||
| changeKeys := len(packet) > 0 && packet[0] == msgNewKeys | |||||
| err := s.packetCipher.writePacket(s.seqNum, w, rand, packet) | |||||
| if err != nil { | |||||
| return err | |||||
| } | |||||
| if err = w.Flush(); err != nil { | |||||
| return err | |||||
| } | |||||
| s.seqNum++ | |||||
| if changeKeys { | |||||
| select { | |||||
| case cipher := <-s.pendingKeyChange: | |||||
| s.packetCipher = cipher | |||||
| default: | |||||
| panic("ssh: no key material for msgNewKeys") | |||||
| } | |||||
| } | |||||
| return err | |||||
| } | |||||
| func newTransport(rwc io.ReadWriteCloser, rand io.Reader, isClient bool) *transport { | |||||
| t := &transport{ | |||||
| bufReader: bufio.NewReader(rwc), | |||||
| bufWriter: bufio.NewWriter(rwc), | |||||
| rand: rand, | |||||
| reader: connectionState{ | |||||
| packetCipher: &streamPacketCipher{cipher: noneCipher{}}, | |||||
| pendingKeyChange: make(chan packetCipher, 1), | |||||
| }, | |||||
| writer: connectionState{ | |||||
| packetCipher: &streamPacketCipher{cipher: noneCipher{}}, | |||||
| pendingKeyChange: make(chan packetCipher, 1), | |||||
| }, | |||||
| Closer: rwc, | |||||
| } | |||||
| if isClient { | |||||
| t.reader.dir = serverKeys | |||||
| t.writer.dir = clientKeys | |||||
| } else { | |||||
| t.reader.dir = clientKeys | |||||
| t.writer.dir = serverKeys | |||||
| } | |||||
| return t | |||||
| } | |||||
| type direction struct { | |||||
| ivTag []byte | |||||
| keyTag []byte | |||||
| macKeyTag []byte | |||||
| } | |||||
| var ( | |||||
| serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}} | |||||
| clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}} | |||||
| ) | |||||
| // generateKeys generates key material for IV, MAC and encryption. | |||||
| func generateKeys(d direction, algs directionAlgorithms, kex *kexResult) (iv, key, macKey []byte) { | |||||
| cipherMode := cipherModes[algs.Cipher] | |||||
| macMode := macModes[algs.MAC] | |||||
| iv = make([]byte, cipherMode.ivSize) | |||||
| key = make([]byte, cipherMode.keySize) | |||||
| macKey = make([]byte, macMode.keySize) | |||||
| generateKeyMaterial(iv, d.ivTag, kex) | |||||
| generateKeyMaterial(key, d.keyTag, kex) | |||||
| generateKeyMaterial(macKey, d.macKeyTag, kex) | |||||
| return | |||||
| } | |||||
| // setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as | |||||
| // described in RFC 4253, section 6.4. direction should either be serverKeys | |||||
| // (to setup server->client keys) or clientKeys (for client->server keys). | |||||
| func newPacketCipher(d direction, algs directionAlgorithms, kex *kexResult) (packetCipher, error) { | |||||
| iv, key, macKey := generateKeys(d, algs, kex) | |||||
| if algs.Cipher == gcmCipherID { | |||||
| return newGCMCipher(iv, key, macKey) | |||||
| } | |||||
| if algs.Cipher == aes128cbcID { | |||||
| return newAESCBCCipher(iv, key, macKey, algs) | |||||
| } | |||||
| c := &streamPacketCipher{ | |||||
| mac: macModes[algs.MAC].new(macKey), | |||||
| } | |||||
| c.macResult = make([]byte, c.mac.Size()) | |||||
| var err error | |||||
| c.cipher, err = cipherModes[algs.Cipher].createStream(key, iv) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| return c, nil | |||||
| } | |||||
| // generateKeyMaterial fills out with key material generated from tag, K, H | |||||
| // and sessionId, as specified in RFC 4253, section 7.2. | |||||
| func generateKeyMaterial(out, tag []byte, r *kexResult) { | |||||
| var digestsSoFar []byte | |||||
| h := r.Hash.New() | |||||
| for len(out) > 0 { | |||||
| h.Reset() | |||||
| h.Write(r.K) | |||||
| h.Write(r.H) | |||||
| if len(digestsSoFar) == 0 { | |||||
| h.Write(tag) | |||||
| h.Write(r.SessionID) | |||||
| } else { | |||||
| h.Write(digestsSoFar) | |||||
| } | |||||
| digest := h.Sum(nil) | |||||
| n := copy(out, digest) | |||||
| out = out[n:] | |||||
| if len(out) > 0 { | |||||
| digestsSoFar = append(digestsSoFar, digest...) | |||||
| } | |||||
| } | |||||
| } | |||||
| const packageVersion = "SSH-2.0-Go" | |||||
| // Sends and receives a version line. The versionLine string should | |||||
| // be US ASCII, start with "SSH-2.0-", and should not include a | |||||
| // newline. exchangeVersions returns the other side's version line. | |||||
| func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) { | |||||
| // Contrary to the RFC, we do not ignore lines that don't | |||||
| // start with "SSH-2.0-" to make the library usable with | |||||
| // nonconforming servers. | |||||
| for _, c := range versionLine { | |||||
| // The spec disallows non US-ASCII chars, and | |||||
| // specifically forbids null chars. | |||||
| if c < 32 { | |||||
| return nil, errors.New("ssh: junk character in version line") | |||||
| } | |||||
| } | |||||
| if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil { | |||||
| return | |||||
| } | |||||
| them, err = readVersion(rw) | |||||
| return them, err | |||||
| } | |||||
| // maxVersionStringBytes is the maximum number of bytes that we'll | |||||
| // accept as a version string. RFC 4253 section 4.2 limits this at 255 | |||||
| // chars | |||||
| const maxVersionStringBytes = 255 | |||||
| // Read version string as specified by RFC 4253, section 4.2. | |||||
| func readVersion(r io.Reader) ([]byte, error) { | |||||
| versionString := make([]byte, 0, 64) | |||||
| var ok bool | |||||
| var buf [1]byte | |||||
| for len(versionString) < maxVersionStringBytes { | |||||
| _, err := io.ReadFull(r, buf[:]) | |||||
| if err != nil { | |||||
| return nil, err | |||||
| } | |||||
| // The RFC says that the version should be terminated with \r\n | |||||
| // but several SSH servers actually only send a \n. | |||||
| if buf[0] == '\n' { | |||||
| ok = true | |||||
| break | |||||
| } | |||||
| // non ASCII chars are disallowed, but we are lenient, | |||||
| // since Go doesn't use null-terminated strings. | |||||
| // The RFC allows a comment after a space, however, | |||||
| // all of it (version and comments) goes into the | |||||
| // session hash. | |||||
| versionString = append(versionString, buf[0]) | |||||
| } | |||||
| if !ok { | |||||
| return nil, errors.New("ssh: overflow reading version string") | |||||
| } | |||||
| // There might be a '\r' on the end which we should remove. | |||||
| if len(versionString) > 0 && versionString[len(versionString)-1] == '\r' { | |||||
| versionString = versionString[:len(versionString)-1] | |||||
| } | |||||
| return versionString, nil | |||||
| } | |||||
| @@ -1,109 +0,0 @@ | |||||
| // Copyright 2011 The Go Authors. All rights reserved. | |||||
| // Use of this source code is governed by a BSD-style | |||||
| // license that can be found in the LICENSE file. | |||||
| package ssh | |||||
| import ( | |||||
| "bytes" | |||||
| "crypto/rand" | |||||
| "encoding/binary" | |||||
| "strings" | |||||
| "testing" | |||||
| ) | |||||
| func TestReadVersion(t *testing.T) { | |||||
| longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] | |||||
| cases := map[string]string{ | |||||
| "SSH-2.0-bla\r\n": "SSH-2.0-bla", | |||||
| "SSH-2.0-bla\n": "SSH-2.0-bla", | |||||
| longversion + "\r\n": longversion, | |||||
| } | |||||
| for in, want := range cases { | |||||
| result, err := readVersion(bytes.NewBufferString(in)) | |||||
| if err != nil { | |||||
| t.Errorf("readVersion(%q): %s", in, err) | |||||
| } | |||||
| got := string(result) | |||||
| if got != want { | |||||
| t.Errorf("got %q, want %q", got, want) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestReadVersionError(t *testing.T) { | |||||
| longversion := strings.Repeat("SSH-2.0-bla", 50)[:253] | |||||
| cases := []string{ | |||||
| longversion + "too-long\r\n", | |||||
| } | |||||
| for _, in := range cases { | |||||
| if _, err := readVersion(bytes.NewBufferString(in)); err == nil { | |||||
| t.Errorf("readVersion(%q) should have failed", in) | |||||
| } | |||||
| } | |||||
| } | |||||
| func TestExchangeVersionsBasic(t *testing.T) { | |||||
| v := "SSH-2.0-bla" | |||||
| buf := bytes.NewBufferString(v + "\r\n") | |||||
| them, err := exchangeVersions(buf, []byte("xyz")) | |||||
| if err != nil { | |||||
| t.Errorf("exchangeVersions: %v", err) | |||||
| } | |||||
| if want := "SSH-2.0-bla"; string(them) != want { | |||||
| t.Errorf("got %q want %q for our version", them, want) | |||||
| } | |||||
| } | |||||
| func TestExchangeVersions(t *testing.T) { | |||||
| cases := []string{ | |||||
| "not\x000allowed", | |||||
| "not allowed\n", | |||||
| } | |||||
| for _, c := range cases { | |||||
| buf := bytes.NewBufferString("SSH-2.0-bla\r\n") | |||||
| if _, err := exchangeVersions(buf, []byte(c)); err == nil { | |||||
| t.Errorf("exchangeVersions(%q): should have failed", c) | |||||
| } | |||||
| } | |||||
| } | |||||
| type closerBuffer struct { | |||||
| bytes.Buffer | |||||
| } | |||||
| func (b *closerBuffer) Close() error { | |||||
| return nil | |||||
| } | |||||
| func TestTransportMaxPacketWrite(t *testing.T) { | |||||
| buf := &closerBuffer{} | |||||
| tr := newTransport(buf, rand.Reader, true) | |||||
| huge := make([]byte, maxPacket+1) | |||||
| err := tr.writePacket(huge) | |||||
| if err == nil { | |||||
| t.Errorf("transport accepted write for a huge packet.") | |||||
| } | |||||
| } | |||||
| func TestTransportMaxPacketReader(t *testing.T) { | |||||
| var header [5]byte | |||||
| huge := make([]byte, maxPacket+128) | |||||
| binary.BigEndian.PutUint32(header[0:], uint32(len(huge))) | |||||
| // padding. | |||||
| header[4] = 0 | |||||
| buf := &closerBuffer{} | |||||
| buf.Write(header[:]) | |||||
| buf.Write(huge) | |||||
| tr := newTransport(buf, rand.Reader, true) | |||||
| _, err := tr.readPacket() | |||||
| if err == nil { | |||||
| t.Errorf("transport succeeded reading huge packet.") | |||||
| } else if !strings.Contains(err.Error(), "large") { | |||||
| t.Errorf("got %q, should mention %q", err.Error(), "large") | |||||
| } | |||||
| } | |||||
| @@ -22,7 +22,6 @@ import ( | |||||
| "github.com/gogits/gogs/modules/bindata" | "github.com/gogits/gogs/modules/bindata" | ||||
| "github.com/gogits/gogs/modules/log" | "github.com/gogits/gogs/modules/log" | ||||
| // "github.com/gogits/gogs/modules/ssh" | |||||
| "github.com/gogits/gogs/modules/user" | "github.com/gogits/gogs/modules/user" | ||||
| ) | ) | ||||
| @@ -51,6 +50,7 @@ var ( | |||||
| AppName string | AppName string | ||||
| AppUrl string | AppUrl string | ||||
| AppSubUrl string | AppSubUrl string | ||||
| AppPath string | |||||
| AppDataPath = "data" | AppDataPath = "data" | ||||
| // Server settings. | // Server settings. | ||||
| @@ -58,8 +58,9 @@ var ( | |||||
| Domain string | Domain string | ||||
| HttpAddr, HttpPort string | HttpAddr, HttpPort string | ||||
| DisableSSH bool | DisableSSH bool | ||||
| SSHPort int | |||||
| StartSSHServer bool | |||||
| SSHDomain string | SSHDomain string | ||||
| SSHPort int | |||||
| OfflineMode bool | OfflineMode bool | ||||
| DisableRouterLog bool | DisableRouterLog bool | ||||
| CertFile, KeyFile string | CertFile, KeyFile string | ||||
| @@ -196,21 +197,27 @@ func DateLang(lang string) string { | |||||
| return "en" | return "en" | ||||
| } | } | ||||
| func init() { | |||||
| IsWindows = runtime.GOOS == "windows" | |||||
| log.NewLogger(0, "console", `{"level": 0}`) | |||||
| } | |||||
| func ExecPath() (string, error) { | |||||
| // execPath returns the executable path. | |||||
| func execPath() (string, error) { | |||||
| file, err := exec.LookPath(os.Args[0]) | file, err := exec.LookPath(os.Args[0]) | ||||
| if err != nil { | if err != nil { | ||||
| return "", err | return "", err | ||||
| } | } | ||||
| p, err := filepath.Abs(file) | |||||
| if err != nil { | |||||
| return "", err | |||||
| return filepath.Abs(file) | |||||
| } | |||||
| func init() { | |||||
| IsWindows = runtime.GOOS == "windows" | |||||
| log.NewLogger(0, "console", `{"level": 0}`) | |||||
| var err error | |||||
| if AppPath, err = execPath(); err != nil { | |||||
| log.Fatal(4, "fail to get app path: %v\n", err) | |||||
| } | } | ||||
| return p, nil | |||||
| // Note: we don't use path.Dir here because it does not handle case | |||||
| // which path starts with two "/" in Windows: "//psf/Home/..." | |||||
| AppPath = strings.Replace(AppPath, "\\", "/", -1) | |||||
| } | } | ||||
| // WorkDir returns absolute path of work directory. | // WorkDir returns absolute path of work directory. | ||||
| @@ -220,19 +227,11 @@ func WorkDir() (string, error) { | |||||
| return wd, nil | return wd, nil | ||||
| } | } | ||||
| execPath, err := ExecPath() | |||||
| if err != nil { | |||||
| return execPath, err | |||||
| } | |||||
| // Note: we don't use path.Dir here because it does not handle case | |||||
| // which path starts with two "/" in Windows: "//psf/Home/..." | |||||
| execPath = strings.Replace(execPath, "\\", "/", -1) | |||||
| i := strings.LastIndex(execPath, "/") | |||||
| i := strings.LastIndex(AppPath, "/") | |||||
| if i == -1 { | if i == -1 { | ||||
| return execPath, nil | |||||
| return AppPath, nil | |||||
| } | } | ||||
| return execPath[:i], nil | |||||
| return AppPath[:i], nil | |||||
| } | } | ||||
| func forcePathSeparator(path string) { | func forcePathSeparator(path string) { | ||||
| @@ -301,6 +300,9 @@ func NewContext() { | |||||
| HttpAddr = sec.Key("HTTP_ADDR").MustString("0.0.0.0") | HttpAddr = sec.Key("HTTP_ADDR").MustString("0.0.0.0") | ||||
| HttpPort = sec.Key("HTTP_PORT").MustString("3000") | HttpPort = sec.Key("HTTP_PORT").MustString("3000") | ||||
| DisableSSH = sec.Key("DISABLE_SSH").MustBool() | DisableSSH = sec.Key("DISABLE_SSH").MustBool() | ||||
| if !DisableSSH { | |||||
| StartSSHServer = sec.Key("START_SSH_SERVER").MustBool() | |||||
| } | |||||
| SSHDomain = sec.Key("SSH_DOMAIN").MustString(Domain) | SSHDomain = sec.Key("SSH_DOMAIN").MustString(Domain) | ||||
| SSHPort = sec.Key("SSH_PORT").MustInt(22) | SSHPort = sec.Key("SSH_PORT").MustInt(22) | ||||
| OfflineMode = sec.Key("OFFLINE_MODE").MustBool() | OfflineMode = sec.Key("OFFLINE_MODE").MustBool() | ||||
| @@ -655,5 +657,4 @@ func NewServices() { | |||||
| newRegisterMailService() | newRegisterMailService() | ||||
| newNotifyMailService() | newNotifyMailService() | ||||
| newWebhookService() | newWebhookService() | ||||
| // ssh.Listen("2222") | |||||
| } | } | ||||
| @@ -1,80 +1,110 @@ | |||||
| // +build go1.4 | |||||
| // Copyright 2014 The Gogs Authors. All rights reserved. | // Copyright 2014 The Gogs Authors. All rights reserved. | ||||
| // Use of this source code is governed by a MIT-style | // Use of this source code is governed by a MIT-style | ||||
| // license that can be found in the LICENSE file. | // license that can be found in the LICENSE file. | ||||
| // Prototype, git client looks like do not recognize req.Reply. | |||||
| package ssh | package ssh | ||||
| import ( | import ( | ||||
| "fmt" | |||||
| "io" | |||||
| "io/ioutil" | "io/ioutil" | ||||
| "net" | "net" | ||||
| "os" | "os" | ||||
| "os/exec" | "os/exec" | ||||
| "path/filepath" | |||||
| "strings" | "strings" | ||||
| "github.com/Unknwon/com" | "github.com/Unknwon/com" | ||||
| "golang.org/x/crypto/ssh" | |||||
| "github.com/gogits/gogs/modules/crypto/ssh" | |||||
| "github.com/gogits/gogs/models" | |||||
| "github.com/gogits/gogs/modules/log" | "github.com/gogits/gogs/modules/log" | ||||
| "github.com/gogits/gogs/modules/setting" | |||||
| ) | ) | ||||
| func handleServerConn(keyId string, chans <-chan ssh.NewChannel) { | |||||
| func cleanCommand(cmd string) string { | |||||
| i := strings.Index(cmd, "git") | |||||
| if i == -1 { | |||||
| return cmd | |||||
| } | |||||
| return cmd[i:] | |||||
| } | |||||
| func handleServerConn(keyID string, chans <-chan ssh.NewChannel) { | |||||
| for newChan := range chans { | for newChan := range chans { | ||||
| if newChan.ChannelType() != "session" { | if newChan.ChannelType() != "session" { | ||||
| newChan.Reject(ssh.UnknownChannelType, "unknown channel type") | newChan.Reject(ssh.UnknownChannelType, "unknown channel type") | ||||
| continue | continue | ||||
| } | } | ||||
| channel, requests, err := newChan.Accept() | |||||
| ch, reqs, err := newChan.Accept() | |||||
| if err != nil { | if err != nil { | ||||
| log.Error(3, "Could not accept channel: %v", err) | |||||
| log.Error(3, "Error accepting channel: %v", err) | |||||
| continue | continue | ||||
| } | } | ||||
| go func(in <-chan *ssh.Request) { | go func(in <-chan *ssh.Request) { | ||||
| defer channel.Close() | |||||
| defer ch.Close() | |||||
| for req := range in { | for req := range in { | ||||
| ok, payload := false, strings.TrimLeft(string(req.Payload), "\x00&") | |||||
| fmt.Println("Request:", req.Type, req.WantReply, payload) | |||||
| if req.WantReply { | |||||
| fmt.Println(req.Reply(true, nil)) | |||||
| } | |||||
| payload := cleanCommand(string(req.Payload)) | |||||
| switch req.Type { | switch req.Type { | ||||
| case "env": | case "env": | ||||
| args := strings.Split(strings.Replace(payload, "\x00", "", -1), "\v") | args := strings.Split(strings.Replace(payload, "\x00", "", -1), "\v") | ||||
| if len(args) != 2 { | if len(args) != 2 { | ||||
| break | |||||
| return | |||||
| } | } | ||||
| args[0] = strings.TrimLeft(args[0], "\x04") | args[0] = strings.TrimLeft(args[0], "\x04") | ||||
| _, _, err := com.ExecCmdBytes("env", args[0]+"="+args[1]) | _, _, err := com.ExecCmdBytes("env", args[0]+"="+args[1]) | ||||
| if err != nil { | if err != nil { | ||||
| log.Error(3, "env: %v", err) | log.Error(3, "env: %v", err) | ||||
| channel.Stderr().Write([]byte(err.Error())) | |||||
| break | |||||
| return | |||||
| } | } | ||||
| ok = true | |||||
| case "exec": | case "exec": | ||||
| os.Setenv("SSH_ORIGINAL_COMMAND", strings.TrimLeft(payload, "'(")) | |||||
| log.Info("Payload: %v", strings.TrimLeft(payload, "'(")) | |||||
| cmd := exec.Command("/Users/jiahuachen/Applications/Go/src/github.com/gogits/gogs/gogs", "serv", "key-"+keyId) | |||||
| cmd.Stdout = channel | |||||
| cmd.Stdin = channel | |||||
| cmd.Stderr = channel.Stderr() | |||||
| if err := cmd.Run(); err != nil { | |||||
| log.Error(3, "exec: %v", err) | |||||
| } else { | |||||
| ok = true | |||||
| cmdName := strings.TrimLeft(payload, "'()") | |||||
| os.Setenv("SSH_ORIGINAL_COMMAND", cmdName) | |||||
| log.Trace("Payload: %v", cmdName) | |||||
| cmd := exec.Command(setting.AppPath, "serv", "key-"+keyID) | |||||
| stdout, err := cmd.StdoutPipe() | |||||
| if err != nil { | |||||
| log.Error(3, "StdoutPipe: %v", err) | |||||
| return | |||||
| } | |||||
| stderr, err := cmd.StderrPipe() | |||||
| if err != nil { | |||||
| log.Error(3, "StderrPipe: %v", err) | |||||
| return | |||||
| } | |||||
| input, err := cmd.StdinPipe() | |||||
| if err != nil { | |||||
| log.Error(3, "StdinPipe: %v", err) | |||||
| return | |||||
| } | } | ||||
| go io.Copy(ch, stdout) | |||||
| go io.Copy(ch.Stderr(), stderr) | |||||
| go io.Copy(input, ch) | |||||
| if err = cmd.Start(); err != nil { | |||||
| log.Error(3, "Start: %v", err) | |||||
| return | |||||
| } else if err = cmd.Wait(); err != nil { | |||||
| log.Error(3, "Wait: %v", err) | |||||
| return | |||||
| } | |||||
| ch.SendRequest("exit-status", false, []byte{0, 0, 0, 0}) | |||||
| return | |||||
| default: | |||||
| } | } | ||||
| fmt.Println("Done:", ok) | |||||
| } | } | ||||
| fmt.Println("Done!!!") | |||||
| }(requests) | |||||
| }(reqs) | |||||
| } | } | ||||
| } | } | ||||
| func listen(config *ssh.ServerConfig, port string) { | |||||
| listener, err := net.Listen("tcp", "0.0.0.0:"+port) | |||||
| func listen(config *ssh.ServerConfig, port int) { | |||||
| listener, err := net.Listen("tcp", "0.0.0.0:"+com.ToStr(port)) | |||||
| if err != nil { | if err != nil { | ||||
| panic(err) | panic(err) | ||||
| } | } | ||||
| @@ -82,15 +112,17 @@ func listen(config *ssh.ServerConfig, port string) { | |||||
| // Once a ServerConfig has been configured, connections can be accepted. | // Once a ServerConfig has been configured, connections can be accepted. | ||||
| conn, err := listener.Accept() | conn, err := listener.Accept() | ||||
| if err != nil { | if err != nil { | ||||
| log.Error(3, "Fail to accept incoming connection: %v", err) | |||||
| log.Error(3, "Error accepting incoming connection: %v", err) | |||||
| continue | continue | ||||
| } | } | ||||
| // Before use, a handshake must be performed on the incoming net.Conn. | // Before use, a handshake must be performed on the incoming net.Conn. | ||||
| sConn, chans, reqs, err := ssh.NewServerConn(conn, config) | sConn, chans, reqs, err := ssh.NewServerConn(conn, config) | ||||
| if err != nil { | if err != nil { | ||||
| log.Error(3, "Fail to handshake: %v", err) | |||||
| log.Error(3, "Error on handshaking: %v", err) | |||||
| continue | continue | ||||
| } | } | ||||
| log.Trace("Connection from %s (%s)", sConn.RemoteAddr(), sConn.ClientVersion()) | |||||
| // The incoming Request channel must be serviced. | // The incoming Request channel must be serviced. | ||||
| go ssh.DiscardRequests(reqs) | go ssh.DiscardRequests(reqs) | ||||
| go handleServerConn(sConn.Permissions.Extensions["key-id"], chans) | go handleServerConn(sConn.Permissions.Extensions["key-id"], chans) | ||||
| @@ -98,21 +130,25 @@ func listen(config *ssh.ServerConfig, port string) { | |||||
| } | } | ||||
| // Listen starts a SSH server listens on given port. | // Listen starts a SSH server listens on given port. | ||||
| func Listen(port string) { | |||||
| func Listen(port int) { | |||||
| config := &ssh.ServerConfig{ | config := &ssh.ServerConfig{ | ||||
| PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { | PublicKeyCallback: func(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) { | ||||
| // keyCache[string(ssh.MarshalAuthorizedKey(key))] = 2 | |||||
| return &ssh.Permissions{Extensions: map[string]string{"key-id": "1"}}, nil | |||||
| pkey, err := models.SearchPublicKeyByContent(strings.TrimSpace(string(ssh.MarshalAuthorizedKey(key)))) | |||||
| if err != nil { | |||||
| log.Error(3, "SearchPublicKeyByContent: %v", err) | |||||
| return nil, err | |||||
| } | |||||
| return &ssh.Permissions{Extensions: map[string]string{"key-id": com.ToStr(pkey.ID)}}, nil | |||||
| }, | }, | ||||
| } | } | ||||
| privateBytes, err := ioutil.ReadFile("/Users/jiahuachen/.ssh/id_rsa") | |||||
| privateBytes, err := ioutil.ReadFile(filepath.Join(models.SSHPath, "id_rsa")) | |||||
| if err != nil { | if err != nil { | ||||
| panic("failed to load private key") | |||||
| panic("Fail to load private key") | |||||
| } | } | ||||
| private, err := ssh.ParsePrivateKey(privateBytes) | private, err := ssh.ParsePrivateKey(privateBytes) | ||||
| if err != nil { | if err != nil { | ||||
| panic("failed to parse private key") | |||||
| panic("Fail to parse private key") | |||||
| } | } | ||||
| config.AddHostKey(private) | config.AddHostKey(private) | ||||
| @@ -0,0 +1,7 @@ | |||||
| // +build !go1.4 | |||||
| package ssh | |||||
| func Listen(port int) { | |||||
| panic("Gogs requires Go 1.4 for starting a SSH server") | |||||
| } | |||||
| @@ -25,6 +25,7 @@ import ( | |||||
| "github.com/gogits/gogs/modules/mailer" | "github.com/gogits/gogs/modules/mailer" | ||||
| "github.com/gogits/gogs/modules/middleware" | "github.com/gogits/gogs/modules/middleware" | ||||
| "github.com/gogits/gogs/modules/setting" | "github.com/gogits/gogs/modules/setting" | ||||
| "github.com/gogits/gogs/modules/ssh" | |||||
| "github.com/gogits/gogs/modules/user" | "github.com/gogits/gogs/modules/user" | ||||
| ) | ) | ||||
| @@ -76,6 +77,11 @@ func GlobalInit() { | |||||
| log.Info("TiDB Supported") | log.Info("TiDB Supported") | ||||
| } | } | ||||
| checkRunMode() | checkRunMode() | ||||
| if setting.StartSSHServer { | |||||
| ssh.Listen(setting.SSHPort) | |||||
| log.Info("SSH server started on :%v", setting.SSHPort) | |||||
| } | |||||
| } | } | ||||
| func InstallInit(ctx *middleware.Context) { | func InstallInit(ctx *middleware.Context) { | ||||