|
- // Copyright (C) MongoDB, Inc. 2017-present.
- //
- // Licensed under the Apache License, Version 2.0 (the "License"); you may
- // not use this file except in compliance with the License. You may obtain
- // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
-
- package options // import "go.mongodb.org/mongo-driver/mongo/options"
-
- import (
- "bytes"
- "context"
- "crypto/tls"
- "crypto/x509"
- "encoding/pem"
- "errors"
- "fmt"
- "io/ioutil"
- "net"
- "strings"
- "time"
-
- "go.mongodb.org/mongo-driver/bson/bsoncodec"
- "go.mongodb.org/mongo-driver/event"
- "go.mongodb.org/mongo-driver/mongo/readconcern"
- "go.mongodb.org/mongo-driver/mongo/readpref"
- "go.mongodb.org/mongo-driver/mongo/writeconcern"
- "go.mongodb.org/mongo-driver/tag"
- "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
- )
-
// ContextDialer makes new network connections.
//
// It is the type accepted by ClientOptions.SetDialer for supplying a custom dialer. DialContext has
// the same signature as (*net.Dialer).DialContext, so a *net.Dialer satisfies this interface.
type ContextDialer interface {
	// DialContext connects to address on the named network (for example "tcp"), honoring ctx.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
-
// Credential holds auth options.
//
// ApplyURI populates a Credential whenever the connection string specifies any of an auth
// mechanism, auth mechanism properties, an auth source, a username, or a password.
type Credential struct {
	// AuthMechanism indicates the mechanism to use for authentication.
	// Supported values include "SCRAM-SHA-256", "SCRAM-SHA-1", "MONGODB-CR", "PLAIN", "GSSAPI", and
	// "MONGODB-X509".
	AuthMechanism string

	// AuthMechanismProperties specifies additional configuration options which may be used by certain
	// authentication mechanisms. Supported properties are:
	//   SERVICE_NAME: Specifies the name of the service. Defaults to mongodb.
	//   CANONICALIZE_HOST_NAME: If true, tells the driver to canonicalize the given hostname. Defaults to
	//   false. This property may not be used on Linux and Darwin systems and may not be used at the same
	//   time as SERVICE_HOST.
	//   SERVICE_REALM: Specifies the realm of the service.
	//   SERVICE_HOST: Specifies a hostname for GSSAPI authentication if it is different from the server's
	//   address. For authentication mechanisms besides GSSAPI, this property is ignored.
	AuthMechanismProperties map[string]string

	// AuthSource specifies the database to authenticate against.
	AuthSource string

	// Username specifies the username that will be authenticated.
	// When ApplyURI loads a TLS client certificate, the mechanism is MONGODB-X509 (compared
	// case-insensitively) and Username is empty, ApplyURI fills Username in from the certificate's
	// subject name.
	Username string

	// Password specifies the password used for authentication.
	Password string

	// PasswordSet specifies if the password is actually set, since an empty password is a valid password.
	PasswordSet bool
}
-
// ClientOptions represents all possible options to configure a client.
//
// Pointer-typed fields are nil when the option has not been set. The Set* methods and ApplyURI fill
// them in, and MergeClientOptions combines several ClientOptions by copying every set field from
// later options over earlier ones.
type ClientOptions struct {
	// AppName is the client application name MongoDB uses when logging connection and profile information.
	AppName *string
	// Auth holds the authentication options.
	Auth *Credential
	// ConnectTimeout is the timeout for an initial connection to a server. It is not applied to a custom Dialer.
	ConnectTimeout *time.Duration
	// Compressors lists the compressors that can be used when communicating with a server.
	Compressors []string
	// Dialer is a custom dialer used to open new connections; nil means the driver's default dialer.
	Dialer ContextDialer
	// HeartbeatInterval is the interval to wait between server monitoring checks.
	HeartbeatInterval *time.Duration
	// Hosts is the initial list of addresses from which to discover the rest of the cluster.
	// ApplyURI always overwrites it with the hosts from the connection string.
	Hosts []string
	// LocalThreshold is how much slower than the fastest server a server's round-trip time may be
	// while still receiving queries.
	LocalThreshold *time.Duration
	// MaxConnIdleTime is how long a connection can remain idle in a pool before being removed and closed.
	MaxConnIdleTime *time.Duration
	// MaxPoolSize is the maximum size of a server's connection pool.
	MaxPoolSize *uint64
	// MinPoolSize is the minimum size of a server's connection pool.
	MinPoolSize *uint64
	// PoolMonitor receives events for a server's connection pool.
	PoolMonitor *event.PoolMonitor
	// Monitor is the command monitor used to observe commands sent by the client.
	Monitor *event.CommandMonitor
	// ReadConcern is the default read concern.
	ReadConcern *readconcern.ReadConcern
	// ReadPreference is the default read preference.
	ReadPreference *readpref.ReadPref
	// Registry is the bsoncodec.Registry used by the client.
	Registry *bsoncodec.Registry
	// ReplicaSet is the name of the replica set of the cluster.
	ReplicaSet *string
	// RetryWrites specifies whether retryable writes are enabled.
	RetryWrites *bool
	// RetryReads is merged by MergeClientOptions, but ApplyURI never sets it; assign it directly.
	RetryReads *bool
	// ServerSelectionTimeout is how long to block waiting for server selection.
	ServerSelectionTimeout *time.Duration
	// Direct specifies whether to connect directly to the server instead of auto-discovering the
	// cluster. ApplyURI derives it from the connection string's connect setting.
	Direct *bool
	// SocketTimeout is how long to attempt to send or receive on a socket before timing out.
	SocketTimeout *time.Duration
	// TLSConfig is the TLS configuration. ApplyURI replaces it entirely when the URI enables SSL.
	TLSConfig *tls.Config
	// WriteConcern is the default write concern.
	WriteConcern *writeconcern.WriteConcern
	// ZlibLevel is the compression level for the zlib compressor.
	ZlibLevel *int

	// err records the first failure seen by ApplyURI (or the last one carried over by
	// MergeClientOptions); Validate returns it.
	err error

	// Adds an option for internal use only and should not be set. This option is deprecated and is
	// not part of the stability guarantee. It may be removed in the future.
	AuthenticateToAnything *bool
}
-
- // Client creates a new ClientOptions instance.
- func Client() *ClientOptions {
- return new(ClientOptions)
- }
-
- // Validate validates the client options. This method will return the first error found.
- func (c *ClientOptions) Validate() error { return c.err }
-
- // ApplyURI parses the provided connection string and sets the values and options accordingly.
- //
- // Errors that occur in this method can be retrieved by calling Validate.
- //
- // If the URI contains ssl=true this method will overwrite TLSConfig, even if there aren't any other
- // tls options specified.
- func (c *ClientOptions) ApplyURI(uri string) *ClientOptions {
- if c.err != nil {
- return c
- }
-
- cs, err := connstring.Parse(uri)
- if err != nil {
- c.err = err
- return c
- }
-
- if cs.AppName != "" {
- c.AppName = &cs.AppName
- }
-
- if cs.AuthMechanism != "" || cs.AuthMechanismProperties != nil || cs.AuthSource != "" ||
- cs.Username != "" || cs.PasswordSet {
- c.Auth = &Credential{
- AuthMechanism: cs.AuthMechanism,
- AuthMechanismProperties: cs.AuthMechanismProperties,
- AuthSource: cs.AuthSource,
- Username: cs.Username,
- Password: cs.Password,
- PasswordSet: cs.PasswordSet,
- }
- }
-
- if cs.ConnectSet {
- direct := cs.Connect == connstring.SingleConnect
- c.Direct = &direct
- }
-
- if cs.ConnectTimeoutSet {
- c.ConnectTimeout = &cs.ConnectTimeout
- }
-
- if len(cs.Compressors) > 0 {
- c.Compressors = cs.Compressors
- }
-
- if cs.HeartbeatIntervalSet {
- c.HeartbeatInterval = &cs.HeartbeatInterval
- }
-
- c.Hosts = cs.Hosts
-
- if cs.LocalThresholdSet {
- c.LocalThreshold = &cs.LocalThreshold
- }
-
- if cs.MaxConnIdleTimeSet {
- c.MaxConnIdleTime = &cs.MaxConnIdleTime
- }
-
- if cs.MaxPoolSizeSet {
- c.MaxPoolSize = &cs.MaxPoolSize
- }
-
- if cs.MinPoolSizeSet {
- c.MinPoolSize = &cs.MinPoolSize
- }
-
- if cs.ReadConcernLevel != "" {
- c.ReadConcern = readconcern.New(readconcern.Level(cs.ReadConcernLevel))
- }
-
- if cs.ReadPreference != "" || len(cs.ReadPreferenceTagSets) > 0 || cs.MaxStalenessSet {
- opts := make([]readpref.Option, 0, 1)
-
- tagSets := tag.NewTagSetsFromMaps(cs.ReadPreferenceTagSets)
- if len(tagSets) > 0 {
- opts = append(opts, readpref.WithTagSets(tagSets...))
- }
-
- if cs.MaxStaleness != 0 {
- opts = append(opts, readpref.WithMaxStaleness(cs.MaxStaleness))
- }
-
- mode, err := readpref.ModeFromString(cs.ReadPreference)
- if err != nil {
- c.err = err
- return c
- }
-
- c.ReadPreference, c.err = readpref.New(mode, opts...)
- if c.err != nil {
- return c
- }
- }
-
- if cs.RetryWritesSet {
- c.RetryWrites = &cs.RetryWrites
- }
-
- if cs.ReplicaSet != "" {
- c.ReplicaSet = &cs.ReplicaSet
- }
-
- if cs.ServerSelectionTimeoutSet {
- c.ServerSelectionTimeout = &cs.ServerSelectionTimeout
- }
-
- if cs.SocketTimeoutSet {
- c.SocketTimeout = &cs.SocketTimeout
- }
-
- if cs.SSL {
- tlsConfig := new(tls.Config)
-
- if cs.SSLCaFileSet {
- c.err = addCACertFromFile(tlsConfig, cs.SSLCaFile)
- if c.err != nil {
- return c
- }
- }
-
- if cs.SSLInsecure {
- tlsConfig.InsecureSkipVerify = true
- }
-
- if cs.SSLClientCertificateKeyFileSet {
- var keyPasswd string
- if cs.SSLClientCertificateKeyPasswordSet && cs.SSLClientCertificateKeyPassword != nil {
- keyPasswd = cs.SSLClientCertificateKeyPassword()
- }
- s, err := addClientCertFromFile(tlsConfig, cs.SSLClientCertificateKeyFile, keyPasswd)
- if err != nil {
- c.err = err
- return c
- }
-
- // If a username wasn't specified, add one from the certificate.
- if c.Auth != nil && strings.ToLower(c.Auth.AuthMechanism) == "mongodb-x509" && c.Auth.Username == "" {
- // The Go x509 package gives the subject with the pairs in reverse order that we want.
- pairs := strings.Split(s, ",")
- for left, right := 0, len(pairs)-1; left < right; left, right = left+1, right-1 {
- pairs[left], pairs[right] = pairs[right], pairs[left]
- }
- c.Auth.Username = strings.Join(pairs, ",")
- }
- }
-
- c.TLSConfig = tlsConfig
- }
-
- if cs.JSet || cs.WString != "" || cs.WNumberSet || cs.WTimeoutSet {
- opts := make([]writeconcern.Option, 0, 1)
-
- if len(cs.WString) > 0 {
- opts = append(opts, writeconcern.WTagSet(cs.WString))
- } else if cs.WNumberSet {
- opts = append(opts, writeconcern.W(cs.WNumber))
- }
-
- if cs.JSet {
- opts = append(opts, writeconcern.J(cs.J))
- }
-
- if cs.WTimeoutSet {
- opts = append(opts, writeconcern.WTimeout(cs.WTimeout))
- }
-
- c.WriteConcern = writeconcern.New(opts...)
- }
-
- if cs.ZlibLevelSet {
- c.ZlibLevel = &cs.ZlibLevel
- }
-
- return c
- }
-
- // SetAppName specifies the client application name. This value is used by MongoDB when it logs
- // connection information and profile information, such as slow queries.
- func (c *ClientOptions) SetAppName(s string) *ClientOptions {
- c.AppName = &s
- return c
- }
-
- // SetAuth sets the authentication options.
- func (c *ClientOptions) SetAuth(auth Credential) *ClientOptions {
- c.Auth = &auth
- return c
- }
-
- // SetCompressors sets the compressors that can be used when communicating with a server.
- func (c *ClientOptions) SetCompressors(comps []string) *ClientOptions {
- c.Compressors = comps
-
- return c
- }
-
- // SetConnectTimeout specifies the timeout for an initial connection to a server.
- // If a custom Dialer is used, this method won't be set and the user is
- // responsible for setting the ConnectTimeout for connections on the dialer
- // themselves.
- func (c *ClientOptions) SetConnectTimeout(d time.Duration) *ClientOptions {
- c.ConnectTimeout = &d
- return c
- }
-
- // SetDialer specifies a custom dialer used to dial new connections to a server.
- // If a custom dialer is not set, a net.Dialer with a 300 second keepalive time will be used by default.
- func (c *ClientOptions) SetDialer(d ContextDialer) *ClientOptions {
- c.Dialer = d
- return c
- }
-
- // SetDirect specifies whether the driver should connect directly to the server instead of
- // auto-discovering other servers in the cluster.
- func (c *ClientOptions) SetDirect(b bool) *ClientOptions {
- c.Direct = &b
- return c
- }
-
- // SetHeartbeatInterval specifies the interval to wait between server monitoring checks.
- func (c *ClientOptions) SetHeartbeatInterval(d time.Duration) *ClientOptions {
- c.HeartbeatInterval = &d
- return c
- }
-
- // SetHosts specifies the initial list of addresses from which to discover the rest of the cluster.
- func (c *ClientOptions) SetHosts(s []string) *ClientOptions {
- c.Hosts = s
- return c
- }
-
- // SetLocalThreshold specifies how far to distribute queries, beyond the server with the fastest
- // round-trip time. If a server's roundtrip time is more than LocalThreshold slower than the
- // the fastest, the driver will not send queries to that server.
- func (c *ClientOptions) SetLocalThreshold(d time.Duration) *ClientOptions {
- c.LocalThreshold = &d
- return c
- }
-
- // SetMaxConnIdleTime specifies the maximum number of milliseconds that a connection can remain idle
- // in a connection pool before being removed and closed.
- func (c *ClientOptions) SetMaxConnIdleTime(d time.Duration) *ClientOptions {
- c.MaxConnIdleTime = &d
- return c
- }
-
- // SetMaxPoolSize specifies the max size of a server's connection pool.
- func (c *ClientOptions) SetMaxPoolSize(u uint64) *ClientOptions {
- c.MaxPoolSize = &u
- return c
- }
-
- // SetMinPoolSize specifies the min size of a server's connection pool.
- func (c *ClientOptions) SetMinPoolSize(u uint64) *ClientOptions {
- c.MinPoolSize = &u
- return c
- }
-
- // SetPoolMonitor specifies the PoolMonitor for a server's connection pool.
- func (c *ClientOptions) SetPoolMonitor(m *event.PoolMonitor) *ClientOptions {
- c.PoolMonitor = m
- return c
- }
-
- // SetMonitor specifies a command monitor used to see commands for a client.
- func (c *ClientOptions) SetMonitor(m *event.CommandMonitor) *ClientOptions {
- c.Monitor = m
- return c
- }
-
- // SetReadConcern specifies the read concern.
- func (c *ClientOptions) SetReadConcern(rc *readconcern.ReadConcern) *ClientOptions {
- c.ReadConcern = rc
-
- return c
- }
-
- // SetReadPreference specifies the read preference.
- func (c *ClientOptions) SetReadPreference(rp *readpref.ReadPref) *ClientOptions {
- c.ReadPreference = rp
-
- return c
- }
-
- // SetRegistry specifies the bsoncodec.Registry.
- func (c *ClientOptions) SetRegistry(registry *bsoncodec.Registry) *ClientOptions {
- c.Registry = registry
- return c
- }
-
- // SetReplicaSet specifies the name of the replica set of the cluster.
- func (c *ClientOptions) SetReplicaSet(s string) *ClientOptions {
- c.ReplicaSet = &s
- return c
- }
-
- // SetRetryWrites specifies whether the client has retryable writes enabled.
- func (c *ClientOptions) SetRetryWrites(b bool) *ClientOptions {
- c.RetryWrites = &b
-
- return c
- }
-
- // SetServerSelectionTimeout specifies a timeout in milliseconds to block for server selection.
- func (c *ClientOptions) SetServerSelectionTimeout(d time.Duration) *ClientOptions {
- c.ServerSelectionTimeout = &d
- return c
- }
-
- // SetSocketTimeout specifies the time in milliseconds to attempt to send or receive on a socket
- // before the attempt times out.
- func (c *ClientOptions) SetSocketTimeout(d time.Duration) *ClientOptions {
- c.SocketTimeout = &d
- return c
- }
-
- // SetTLSConfig sets the tls.Config.
- func (c *ClientOptions) SetTLSConfig(cfg *tls.Config) *ClientOptions {
- c.TLSConfig = cfg
- return c
- }
-
- // SetWriteConcern sets the write concern.
- func (c *ClientOptions) SetWriteConcern(wc *writeconcern.WriteConcern) *ClientOptions {
- c.WriteConcern = wc
-
- return c
- }
-
- // SetZlibLevel sets the level for the zlib compressor.
- func (c *ClientOptions) SetZlibLevel(level int) *ClientOptions {
- c.ZlibLevel = &level
-
- return c
- }
-
- // MergeClientOptions combines the given connstring and *ClientOptions into a single *ClientOptions in a last one wins
- // fashion. The given connstring will be used for the default options, which can be overwritten using the given
- // *ClientOptions.
- func MergeClientOptions(opts ...*ClientOptions) *ClientOptions {
- c := Client()
-
- for _, opt := range opts {
- if opt == nil {
- continue
- }
-
- if opt.Dialer != nil {
- c.Dialer = opt.Dialer
- }
- if opt.AppName != nil {
- c.AppName = opt.AppName
- }
- if opt.Auth != nil {
- c.Auth = opt.Auth
- }
- if opt.AuthenticateToAnything != nil {
- c.AuthenticateToAnything = opt.AuthenticateToAnything
- }
- if opt.Compressors != nil {
- c.Compressors = opt.Compressors
- }
- if opt.ConnectTimeout != nil {
- c.ConnectTimeout = opt.ConnectTimeout
- }
- if opt.HeartbeatInterval != nil {
- c.HeartbeatInterval = opt.HeartbeatInterval
- }
- if len(opt.Hosts) > 0 {
- c.Hosts = opt.Hosts
- }
- if opt.LocalThreshold != nil {
- c.LocalThreshold = opt.LocalThreshold
- }
- if opt.MaxConnIdleTime != nil {
- c.MaxConnIdleTime = opt.MaxConnIdleTime
- }
- if opt.MaxPoolSize != nil {
- c.MaxPoolSize = opt.MaxPoolSize
- }
- if opt.MinPoolSize != nil {
- c.MinPoolSize = opt.MinPoolSize
- }
- if opt.PoolMonitor != nil {
- c.PoolMonitor = opt.PoolMonitor
- }
- if opt.Monitor != nil {
- c.Monitor = opt.Monitor
- }
- if opt.ReadConcern != nil {
- c.ReadConcern = opt.ReadConcern
- }
- if opt.ReadPreference != nil {
- c.ReadPreference = opt.ReadPreference
- }
- if opt.Registry != nil {
- c.Registry = opt.Registry
- }
- if opt.ReplicaSet != nil {
- c.ReplicaSet = opt.ReplicaSet
- }
- if opt.RetryWrites != nil {
- c.RetryWrites = opt.RetryWrites
- }
- if opt.RetryReads != nil {
- c.RetryReads = opt.RetryReads
- }
- if opt.ServerSelectionTimeout != nil {
- c.ServerSelectionTimeout = opt.ServerSelectionTimeout
- }
- if opt.Direct != nil {
- c.Direct = opt.Direct
- }
- if opt.SocketTimeout != nil {
- c.SocketTimeout = opt.SocketTimeout
- }
- if opt.TLSConfig != nil {
- c.TLSConfig = opt.TLSConfig
- }
- if opt.WriteConcern != nil {
- c.WriteConcern = opt.WriteConcern
- }
- if opt.ZlibLevel != nil {
- c.ZlibLevel = opt.ZlibLevel
- }
- if opt.err != nil {
- c.err = opt.err
- }
-
- }
-
- return c
- }
-
// addCACertFromFile trusts every certificate in the PEM file at the given path by adding them to
// cfg.RootCAs, creating the pool when cfg.RootCAs is nil.
//
// CA files frequently bundle several certificates (for example, a root together with
// intermediates, or several roots). All of them must be trusted for chain verification to succeed,
// so the whole file is loaded rather than only its first certificate.
//
// An error is returned if the file cannot be read or contains no valid certificate. In that case
// cfg is left unchanged. Blocks that are not certificates, or that fail to parse, are skipped by
// x509.CertPool.AppendCertsFromPEM and do not cause an error as long as at least one certificate
// in the file is valid.
func addCACertFromFile(cfg *tls.Config, file string) error {
	data, err := ioutil.ReadFile(file)
	if err != nil {
		return err
	}

	pool := cfg.RootCAs
	if pool == nil {
		pool = x509.NewCertPool()
	}
	// AppendCertsFromPEM adds nothing when it reports false, so an existing pool is not disturbed.
	if !pool.AppendCertsFromPEM(data) {
		return errors.New("the specified CA file does not contain any valid certificates")
	}

	cfg.RootCAs = pool
	return nil
}
-
// loadCert returns the DER bytes of the first CERTIFICATE block in the PEM-encoded data, skipping
// any blocks of other types that precede it.
//
// It fails with "invalid .pem file" when the remaining data cannot be decoded as a PEM block. It
// fails with "no CERTIFICATE section found in .pem file" when the data is empty or runs out before
// a CERTIFICATE block is reached. No private key is required or examined.
func loadCert(data []byte) ([]byte, error) {
	for len(data) > 0 {
		block, rest := pem.Decode(data)
		if block == nil {
			return nil, errors.New("invalid .pem file")
		}
		if block.Type == "CERTIFICATE" {
			return block.Bytes, nil
		}
		data = rest
	}

	return nil, errors.New("no CERTIFICATE section found in .pem file")
}
-
- // addClientCertFromFile adds a client certificate to the configuration given a path to the
- // containing file and returns the certificate's subject name.
- func addClientCertFromFile(cfg *tls.Config, clientFile, keyPasswd string) (string, error) {
- data, err := ioutil.ReadFile(clientFile)
- if err != nil {
- return "", err
- }
-
- var currentBlock *pem.Block
- var certBlock, certDecodedBlock, keyBlock []byte
-
- remaining := data
- start := 0
- for {
- currentBlock, remaining = pem.Decode(remaining)
- if currentBlock == nil {
- break
- }
-
- if currentBlock.Type == "CERTIFICATE" {
- certBlock = data[start : len(data)-len(remaining)]
- certDecodedBlock = currentBlock.Bytes
- start += len(certBlock)
- } else if strings.HasSuffix(currentBlock.Type, "PRIVATE KEY") {
- if keyPasswd != "" && x509.IsEncryptedPEMBlock(currentBlock) {
- var encoded bytes.Buffer
- buf, err := x509.DecryptPEMBlock(currentBlock, []byte(keyPasswd))
- if err != nil {
- return "", err
- }
-
- pem.Encode(&encoded, &pem.Block{Type: currentBlock.Type, Bytes: buf})
- keyBlock = encoded.Bytes()
- start = len(data) - len(remaining)
- } else {
- keyBlock = data[start : len(data)-len(remaining)]
- start += len(keyBlock)
- }
- }
- }
- if len(certBlock) == 0 {
- return "", fmt.Errorf("failed to find CERTIFICATE")
- }
- if len(keyBlock) == 0 {
- return "", fmt.Errorf("failed to find PRIVATE KEY")
- }
-
- cert, err := tls.X509KeyPair(certBlock, keyBlock)
- if err != nil {
- return "", err
- }
-
- cfg.Certificates = append(cfg.Certificates, cert)
-
- // The documentation for the tls.X509KeyPair indicates that the Leaf certificate is not
- // retained.
- crt, err := x509.ParseCertificate(certDecodedBlock)
- if err != nil {
- return "", err
- }
-
- return x509CertSubject(crt), nil
- }
|