diff options
-rw-r--r-- | sms-service/doc/coverage.html | 861 | ||||
-rw-r--r-- | sms-service/doc/coverage.md (renamed from sms-service/src/sms/coverage.md) | 0 | ||||
-rw-r--r-- | sms-service/src/quorumclient/quorumclient.go | 22 | ||||
-rw-r--r-- | sms-service/src/sms/Gopkg.lock | 2 | ||||
-rw-r--r-- | sms-service/src/sms/auth/auth.go | 82 | ||||
-rw-r--r-- | sms-service/src/sms/backend/backend.go | 3 | ||||
-rw-r--r-- | sms-service/src/sms/backend/vault.go | 105 | ||||
-rw-r--r-- | sms-service/src/sms/backend/vault_test.go | 38 | ||||
-rw-r--r-- | sms-service/src/sms/handler/handler.go | 54 | ||||
-rw-r--r-- | sms-service/src/sms/handler/handler_test.go | 45 | ||||
-rw-r--r-- | sms-service/src/sms/log/logger.go | 73 |
11 files changed, 809 insertions, 476 deletions
diff --git a/sms-service/doc/coverage.html b/sms-service/doc/coverage.html index d03ddde..39ee191 100644 --- a/sms-service/doc/coverage.html +++ b/sms-service/doc/coverage.html @@ -54,19 +54,19 @@ <div id="nav"> <select id="files"> - <option value="file0">sms/auth/auth.go (17.6%)</option> + <option value="file0">sms/auth/auth.go (76.1%)</option> - <option value="file1">sms/backend/backend.go (66.7%)</option> + <option value="file1">sms/backend/backend.go (80.0%)</option> - <option value="file2">sms/backend/vault.go (60.5%)</option> + <option value="file2">sms/backend/vault.go (72.5%)</option> - <option value="file3">sms/config/config.go (90.9%)</option> + <option value="file3">sms/config/config.go (78.6%)</option> - <option value="file4">sms/handler/handler.go (55.1%)</option> + <option value="file4">sms/handler/handler.go (63.0%)</option> - <option value="file5">sms/log/logger.go (31.2%)</option> + <option value="file5">sms/log/logger.go (65.6%)</option> - <option value="file6">sms/sms.go (82.6%)</option> + <option value="file6">sms/sms.go (77.8%)</option> </select> </div> @@ -109,6 +109,7 @@ package auth import ( "bytes" + "crypto" "crypto/tls" "crypto/x509" "encoding/base64" @@ -119,63 +120,63 @@ import ( smslogger "sms/log" ) -var tlsConfig *tls.Config - // GetTLSConfig initializes a tlsConfig using the CA's certificate // This config is then used to enable the server for mutual TLS func GetTLSConfig(caCertFile string) (*tls.Config, error) <span class="cov10" title="3">{ + // Initialize tlsConfig once - if tlsConfig == nil </span><span class="cov10" title="3">{ - caCert, err := ioutil.ReadFile(caCertFile) + caCert, err := ioutil.ReadFile(caCertFile) - if err != nil </span><span class="cov1" title="1">{ - return nil, err - }</span> + if err != nil </span><span class="cov1" title="1">{ + return nil, err + }</span> - <span class="cov6" title="2">caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + <span class="cov6" title="2">caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) - tlsConfig = &tls.Config{ - ClientAuth: tls.RequireAndVerifyClientCert, - ClientCAs: caCertPool, - MinVersion: tls.VersionTLS12, - } - tlsConfig.BuildNameToCertificate()</span> + tlsConfig := &tls.Config{ + // Change to RequireAndVerify once we have mandatory certs + ClientAuth: tls.VerifyClientCertIfGiven, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, } - <span class="cov6" title="2">return tlsConfig, nil</span> + tlsConfig.BuildNameToCertificate() + return tlsConfig, nil</span> } // GeneratePGPKeyPair produces a PGP key pair and returns // two things: // A base64 encoded form of the public part of the entity // A base64 encoded form of the private key -func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{ +func GeneratePGPKeyPair() (string, string, error) <span class="cov10" title="3">{ + var entity *openpgp.Entity - entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", nil) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + config := &packet.Config{ + DefaultHash: crypto.SHA256, + } + + entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config) + if smslogger.CheckError(err, "Create Entity") != nil </span><span class="cov0" title="0">{ return "", "", err }</span> // Sign the identity in the entity - <span class="cov0" title="0">for _, id := range entity.Identities </span><span class="cov0" title="0">{ + <span class="cov10" title="3">for _, id := range entity.Identities </span><span class="cov10" title="3">{ err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Entity") != nil </span><span class="cov0" title="0">{ return "", "", err }</span> } // Sign the subkey in the entity - <span class="cov0" title="0">for _, subkey := range entity.Subkeys </span><span class="cov0" title="0">{ + <span class="cov10" title="3">for _, subkey := range entity.Subkeys </span><span class="cov10" title="3">{ err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Subkey") != nil </span><span class="cov0" title="0">{ return "", "", err }</span> } - <span class="cov0" title="0">buffer := new(bytes.Buffer) + <span class="cov10" title="3">buffer := new(bytes.Buffer) entity.Serialize(buffer) pbkey := base64.StdEncoding.EncodeToString(buffer.Bytes()) @@ -186,40 +187,96 @@ func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{ return pbkey, prkey, nil</span> } -// DecryptPGPBytes decrypts a PGP encoded input string and returns +// EncryptPGPString takes data and a public key and encrypts using that +// public key +func EncryptPGPString(data string, pbKey string) (string, error) <span class="cov6" title="2">{ + + pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey) + if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + + <span class="cov6" title="2">dataBytes := []byte(data) + + pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes))) + if smslogger.CheckError(err, "Reading entity from PGP key") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + + // encrypt string + <span class="cov6" title="2">buf := new(bytes.Buffer) + out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil) + if smslogger.CheckError(err, "Creating Encryption Pipe") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + + <span class="cov6" title="2">_, err = out.Write(dataBytes) + if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + + <span class="cov6" title="2">err = out.Close() + if smslogger.CheckError(err, "Closing Encryption Pipe") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + + <span class="cov6" title="2">crp := base64.StdEncoding.EncodeToString(buf.Bytes()) + return crp, nil</span> +} + +// DecryptPGPString decrypts a PGP encoded input string and returns // a base64 representation of the decoded string -func DecryptPGPBytes(data string, prKey string) (string, error) <span class="cov0" title="0">{ +func DecryptPGPString(data string, prKey string) (string, error) <span class="cov1" title="1">{ + // Convert private key to bytes from base64 prKeyBytes, err := base64.StdEncoding.DecodeString(prKey) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError("Error Decoding base64 private key: " + err.Error()) + if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil </span><span class="cov0" title="0">{ return "", err }</span> - <span class="cov0" title="0">dataBytes, err := base64.StdEncoding.DecodeString(data) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError("Error Decoding base64 data: " + err.Error()) + <span class="cov1" title="1">dataBytes, err := base64.StdEncoding.DecodeString(data) + if smslogger.CheckError(err, "Decoding base64 data") != nil </span><span class="cov0" title="0">{ return "", err }</span> - <span class="cov0" title="0">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes))) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError("Error reading entity from PGP key: " + err.Error()) + <span class="cov1" title="1">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes))) + if smslogger.CheckError(err, "Read Entity") != nil </span><span class="cov0" title="0">{ return "", err }</span> - <span class="cov0" title="0">prEntityList := &openpgp.EntityList{prEntity} + <span class="cov1" title="1">prEntityList := &openpgp.EntityList{prEntity} message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError("Error Decrypting message: " + err.Error()) + if smslogger.CheckError(err, "Decrypting Message") != nil </span><span class="cov0" title="0">{ return "", err }</span> - <span class="cov0" title="0">var retBuf bytes.Buffer + <span class="cov1" title="1">var retBuf bytes.Buffer retBuf.ReadFrom(message.UnverifiedBody) return retBuf.String(), nil</span> } + +// ReadFromFile reads a file and loads the PGP key into +// a string +func ReadFromFile(fileName string) (string, error) <span class="cov6" title="2">{ + + data, err := ioutil.ReadFile(fileName) + if smslogger.CheckError(err, "Read from file") != nil </span><span class="cov0" title="0">{ + return "", err + }</span> + <span class="cov6" title="2">return string(data), nil</span> +} + +// WriteToFile writes a PGP key into a file. +// It will truncate the file if it exists +func WriteToFile(data string, fileName string) error <span class="cov0" title="0">{ + + err := ioutil.WriteFile(fileName, []byte(data), 0600) + if smslogger.CheckError(err, "Write to file") != nil </span><span class="cov0" title="0">{ + return err + }</span> + <span class="cov0" title="0">return nil</span> +} </pre> <pre class="file" id="file1" style="display: none">/* @@ -264,6 +321,7 @@ type SecretBackend interface { Init() error GetStatus() (bool, error) Unseal(shard string) error + RegisterQuorum(pgpkey string) (string, error) GetSecret(dom string, sec string) (Secret, error) ListSecret(dom string) ([]string, error) @@ -276,19 +334,18 @@ type SecretBackend interface { } // InitSecretBackend returns an interface implementation -func InitSecretBackend() (SecretBackend, error) <span class="cov10" title="2">{ +func InitSecretBackend() (SecretBackend, error) <span class="cov8" title="1">{ backendImpl := &Vault{ - vaultAddress: smsconfig.SMSConfig.VaultAddress, + vaultAddress: smsconfig.SMSConfig.BackendAddress, vaultToken: smsconfig.SMSConfig.VaultToken, } err := backendImpl.Init() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "InitSecretBackend") != nil </span><span class="cov0" title="0">{ return nil, err }</span> - <span class="cov10" title="2">return backendImpl, nil</span> + <span class="cov8" title="1">return backendImpl, nil</span> } // LoginBackend Interface that will be implemented for various login backends @@ -330,68 +387,108 @@ import ( // Vault is the main Struct used in Backend to initialize the struct type Vault struct { sync.Mutex - engineType string - initRoleDone bool - policyName string - roleID string - secretID string - vaultAddress string - vaultClient *vaultapi.Client - vaultMount string - vaultTempTokenTTL time.Time - vaultToken string - unsealShards []string - rootToken string - pgpPub string - pgpPr string + initRoleDone bool + policyName string + roleID string + secretID string + vaultAddress string + vaultClient *vaultapi.Client + vaultMountPrefix string + internalDomain string + internalDomainMounted bool + vaultTempTokenTTL time.Time + vaultToken string + shards []string + prkey string } -// Init will initialize the vault connection -// It will also create the initial policy if it does not exist -// TODO: Check to see if we need to wait for vault to be running -func (v *Vault) Init() error <span class="cov4" title="3">{ +// initVaultClient will create the initial +// Vault strcuture and populate it with the +// right values and it will also create +// a vault client +func (v *Vault) initVaultClient() error <span class="cov6" title="11">{ + vaultCFG := vaultapi.DefaultConfig() vaultCFG.Address = v.vaultAddress client, err := vaultapi.NewClient(vaultCFG) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) - return errors.New("Unable to create new vault client") + if smslogger.CheckError(err, "Create new vault client") != nil </span><span class="cov0" title="0">{ + return err }</span> - <span class="cov4" title="3">v.engineType = "kv" - v.initRoleDone = false + <span class="cov6" title="11">v.initRoleDone = false v.policyName = "smsvaultpolicy" v.vaultClient = client - v.vaultMount = "sms" + v.vaultMountPrefix = "sms" + v.internalDomain = "smsinternaldomain" + v.internalDomainMounted = false + v.prkey = "" + return nil</span> +} - err = v.initRole() - if err != nil </span><span class="cov2" title="2">{ - smslogger.WriteError(err.Error()) +// Init will initialize the vault connection +// It will also initialize vault if it is not +// already initialized. +// The initial policy will also be created +func (v *Vault) Init() error <span class="cov1" title="1">{ + + v.initVaultClient() + // Initialize vault if it is not already + // Returns immediately if it is initialized + v.initializeVault() + + err := v.initRole() + if smslogger.CheckError(err, "InitRole First Attempt") != nil </span><span class="cov0" title="0">{ smslogger.WriteInfo("InitRole will try again later") }</span> - <span class="cov4" title="3">return nil</span> + <span class="cov1" title="1">return nil</span> } // GetStatus returns the current seal status of vault -func (v *Vault) GetStatus() (bool, error) <span class="cov4" title="3">{ +func (v *Vault) GetStatus() (bool, error) <span class="cov3" title="3">{ + sys := v.vaultClient.Sys() sealStatus, err := sys.SealStatus() - if err != nil </span><span class="cov1" title="1">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Getting Status") != nil </span><span class="cov0" title="0">{ return false, errors.New("Error getting status") }</span> - <span class="cov2" title="2">return sealStatus.Sealed, nil</span> + <span class="cov3" title="3">return sealStatus.Sealed, nil</span> +} + +// RegisterQuorum registers the PGP public key for a quorum client +// We will return a shard to the client that is registering +func (v *Vault) RegisterQuorum(pgpkey string) (string, error) <span class="cov0" title="0">{ + + v.Lock() + defer v.Unlock() + + if v.shards == nil </span><span class="cov0" title="0">{ + smslogger.WriteError("Invalid operation in RegisterQuorum") + return "", errors.New("Invalid operation") + }</span> + // Pop the slice + <span class="cov0" title="0">var sh string + sh, v.shards = v.shards[len(v.shards)-1], v.shards[:len(v.shards)-1] + if len(v.shards) == 0 </span><span class="cov0" title="0">{ + v.shards = nil + }</span> + + // Decrypt with SMS pgp Key + <span class="cov0" title="0">sh, _ = smsauth.DecryptPGPString(sh, v.prkey) + // Encrypt with Quorum client pgp key + sh, _ = smsauth.EncryptPGPString(sh, pgpkey) + + return sh, nil</span> } // Unseal is a passthrough API that allows any // unseal or initialization processes for the backend func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{ + sys := v.vaultClient.Sys() _, err := sys.Unseal(shard) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Unseal Operation") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to execute unseal operation with specified shard") }</span> @@ -401,80 +498,140 @@ func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{ // GetSecret returns a secret mounted on a particular domain name // The secret itself is referenced via its name which translates to // a mount path in vault -func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov6" title="6">{ +func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov5" title="7">{ + err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Tocken Check") != nil </span><span class="cov0" title="0">{ return Secret{}, errors.New("Token check failed") }</span> - <span class="cov6" title="6">dom = v.vaultMount + "/" + dom + <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom sec, err := v.vaultClient.Logical().Read(dom + "/" + name) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{ return Secret{}, errors.New("Unable to read Secret at provided path") }</span> // sec and err are nil in the case where a path does not exist - <span class="cov6" title="6">if sec == nil </span><span class="cov0" title="0">{ + <span class="cov5" title="7">if sec == nil </span><span class="cov0" title="0">{ smslogger.WriteWarn("Vault read was empty. Invalid Path") return Secret{}, errors.New("Secret not found at the provided path") }</span> - <span class="cov6" title="6">return Secret{Name: name, Values: sec.Data}, nil</span> + <span class="cov5" title="7">return Secret{Name: name, Values: sec.Data}, nil</span> } // ListSecret returns a list of secret names on a particular domain // The values of the secret are not returned -func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov2" title="2">{ +func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov3" title="3">{ + err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ return nil, errors.New("Token check failed") }</span> - <span class="cov2" title="2">dom = v.vaultMount + "/" + dom + <span class="cov3" title="3">dom = v.vaultMountPrefix + "/" + dom sec, err := v.vaultClient.Logical().List(dom) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{ return nil, errors.New("Unable to read Secret at provided path") }</span> // sec and err are nil in the case where a path does not exist - <span class="cov2" title="2">if sec == nil </span><span class="cov0" title="0">{ + <span class="cov3" title="3">if sec == nil </span><span class="cov0" title="0">{ smslogger.WriteWarn("Vaultclient returned empty data") return nil, errors.New("Secret not found at the provided path") }</span> - <span class="cov2" title="2">val, ok := sec.Data["keys"].([]interface{}) + <span class="cov3" title="3">val, ok := sec.Data["keys"].([]interface{}) if !ok </span><span class="cov0" title="0">{ smslogger.WriteError("Secret not found at the provided path") return nil, errors.New("Secret not found at the provided path") }</span> - <span class="cov2" title="2">retval := make([]string, len(val)) - for i, v := range val </span><span class="cov6" title="6">{ + <span class="cov3" title="3">retval := make([]string, len(val)) + for i, v := range val </span><span class="cov5" title="7">{ retval[i] = fmt.Sprint(v) }</span> - <span class="cov2" title="2">return retval, nil</span> + <span class="cov3" title="3">return retval, nil</span> +} + +// Mounts the internal Domain if its not already mounted +func (v *Vault) mountInternalDomain(name string) error <span class="cov5" title="8">{ + + if v.internalDomainMounted </span><span class="cov1" title="1">{ + return nil + }</span> + + <span class="cov5" title="7">name = strings.TrimSpace(name) + mountPath := v.vaultMountPrefix + "/" + name + mountInput := &vaultapi.MountInput{ + Type: "kv", + Description: "Mount point for domain: " + name, + Local: false, + SealWrap: false, + Config: vaultapi.MountConfigInput{}, + } + + err := v.vaultClient.Sys().Mount(mountPath, mountInput) + if smslogger.CheckError(err, "Mount internal Domain") != nil </span><span class="cov1" title="1">{ + if strings.Contains(err.Error(), "existing mount") </span><span class="cov1" title="1">{ + // It is already mounted + v.internalDomainMounted = true + return nil + }</span> + // Ran into some other error mounting it. + <span class="cov0" title="0">return errors.New("Unable to mount internal Domain")</span> + } + + <span class="cov5" title="6">v.internalDomainMounted = true + return nil</span> +} + +// Stores the UUID created for secretdomain in vault +// under v.vaultMountPrefix / smsinternal domain +func (v *Vault) storeUUID(uuid string, name string) error <span class="cov5" title="8">{ + + // Check if token is still valid + err := v.checkToken() + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ + return errors.New("Token Check failed") + }</span> + + <span class="cov5" title="8">err = v.mountInternalDomain(v.internalDomain) + if smslogger.CheckError(err, "Mount Internal Domain") != nil </span><span class="cov0" title="0">{ + return err + }</span> + + <span class="cov5" title="8">secret := Secret{ + Name: name, + Values: map[string]interface{}{ + "uuid": uuid, + }, + } + + err = v.CreateSecret(v.internalDomain, secret) + if smslogger.CheckError(err, "Write UUID to domain") != nil </span><span class="cov0" title="0">{ + return err + }</span> + + <span class="cov5" title="8">return nil</span> } // CreateSecretDomain mounts the kv backend on a path with the given name -func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov2" title="2">{ +func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov5" title="8">{ + // Check if token is still valid err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ return SecretDomain{}, errors.New("Token Check failed") }</span> - <span class="cov2" title="2">name = strings.TrimSpace(name) - mountPath := v.vaultMount + "/" + name + <span class="cov5" title="8">name = strings.TrimSpace(name) + mountPath := v.vaultMountPrefix + "/" + name mountInput := &vaultapi.MountInput{ - Type: v.engineType, + Type: "kv", Description: "Mount point for domain: " + name, Local: false, SealWrap: false, @@ -482,171 +639,212 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span clas } err = v.vaultClient.Sys().Mount(mountPath, mountInput) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Domain") != nil </span><span class="cov0" title="0">{ return SecretDomain{}, errors.New("Unable to create Secret Domain") }</span> - <span class="cov2" title="2">uuid, _ := uuid.GenerateUUID() - return SecretDomain{uuid, name}, nil</span> + <span class="cov5" title="8">uuid, _ := uuid.GenerateUUID() + err = v.storeUUID(uuid, name) + if smslogger.CheckError(err, "Store UUID") != nil </span><span class="cov0" title="0">{ + // Mount was successful at this point. + // Rollback the mount operation since we could not + // store the UUID for the mount. + v.vaultClient.Sys().Unmount(mountPath) + return SecretDomain{}, errors.New("Unable to store Secret Domain UUID. Retry") + }</span> + + <span class="cov5" title="8">return SecretDomain{uuid, name}, nil</span> } // CreateSecret creates a secret mounted on a particular domain name // The secret itself is mounted on a path specified by name -func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov6" title="6">{ +func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov7" title="18">{ + err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ return errors.New("Token check failed") }</span> - <span class="cov6" title="6">dom = v.vaultMount + "/" + dom + <span class="cov7" title="18">dom = v.vaultMountPrefix + "/" + dom // Vault return is empty on successful write // TODO: Check if values is not empty _, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Secret") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to create Secret at provided path") }</span> - <span class="cov6" title="6">return nil</span> + <span class="cov7" title="18">return nil</span> } // DeleteSecretDomain deletes a secret domain which translates to // an unmount operation on the given path in Vault -func (v *Vault) DeleteSecretDomain(name string) error <span class="cov2" title="2">{ +func (v *Vault) DeleteSecretDomain(name string) error <span class="cov3" title="3">{ + err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ return errors.New("Token Check Failed") }</span> - <span class="cov2" title="2">name = strings.TrimSpace(name) - mountPath := v.vaultMount + "/" + name + <span class="cov3" title="3">name = strings.TrimSpace(name) + mountPath := v.vaultMountPrefix + "/" + name err = v.vaultClient.Sys().Unmount(mountPath) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Delete Domain") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to delete domain specified") }</span> - <span class="cov2" title="2">return nil</span> + <span class="cov3" title="3">return nil</span> } // DeleteSecret deletes a secret mounted on the path provided -func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov6" title="6">{ +func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov5" title="7">{ + err := v.checkToken() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{ return errors.New("Token check failed") }</span> - <span class="cov6" title="6">dom = v.vaultMount + "/" + dom + <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom // Vault return is empty on successful delete _, err = v.vaultClient.Logical().Delete(dom + "/" + name) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Delete Secret") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to delete Secret at provided path") }</span> - <span class="cov6" title="6">return nil</span> + <span class="cov5" title="7">return nil</span> } -// initRole is called only once during the service bring up -func (v *Vault) initRole() error <span class="cov4" title="3">{ +// initRole is called only once during SMS bring up +// It initially creates a role and secret id associated with +// that role. Later restarts will use the existing role-id +// and secret-id stored on disk +func (v *Vault) initRole() error <span class="cov10" title="56">{ + + if v.initRoleDone </span><span class="cov9" title="48">{ + return nil + }</span> + // Use the root token once here - v.vaultClient.SetToken(v.vaultToken) + <span class="cov5" title="8">v.vaultClient.SetToken(v.vaultToken) defer v.vaultClient.ClearToken() - rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] } + // Check if roleID and secretID has already been created + rID, error := smsauth.ReadFromFile("auth/role") + if error != nil </span><span class="cov5" title="7">{ + smslogger.WriteWarn("Unable to find RoleID. Generating...") + }</span><span class="cov1" title="1"> else { + sID, error := smsauth.ReadFromFile("auth/secret") + if error != nil </span><span class="cov0" title="0">{ + smslogger.WriteWarn("Unable to find secretID. Generating...") + }</span><span class="cov1" title="1"> else { + v.roleID = rID + v.secretID = sID + v.initRoleDone = true + return nil + }</span> + } + + <span class="cov5" title="7">rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] } path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }` err := v.vaultClient.Sys().PutPolicy(v.policyName, rules) - if err != nil </span><span class="cov2" title="2">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Creating Policy") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to create policy for approle creation") }</span> - <span class="cov1" title="1">rName := v.vaultMount + "-role" - data := map[string]interface{}{ - "token_ttl": "60m", - "policies": [2]string{"default", v.policyName}, - } - //Check if applrole is mounted - authMounts, err := v.vaultClient.Sys().ListAuth() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + <span class="cov5" title="7">authMounts, err := v.vaultClient.Sys().ListAuth() + if smslogger.CheckError(err, "Mount Auth Backend") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to get mounted auth backends") }</span> - <span class="cov1" title="1">approleMounted := false - for k, v := range authMounts </span><span class="cov1" title="1">{ - if v.Type == "approle" && k == "approle/" </span><span class="cov1" title="1">{ + <span class="cov5" title="7">approleMounted := false + for k, v := range authMounts </span><span class="cov5" title="7">{ + if v.Type == "approle" && k == "approle/" </span><span class="cov0" title="0">{ approleMounted = true break</span> } } // Mount approle in case its not already mounted - <span class="cov1" title="1">if !approleMounted </span><span class="cov0" title="0">{ + <span class="cov5" title="7">if !approleMounted </span><span class="cov5" title="7">{ v.vaultClient.Sys().EnableAuth("approle", "approle", "") }</span> + <span class="cov5" title="7">rName := v.vaultMountPrefix + "-role" + data := map[string]interface{}{ + "token_ttl": "60m", + "policies": [2]string{"default", v.policyName}, + } + // Create a role-id - <span class="cov1" title="1">v.vaultClient.Logical().Write("auth/approle/role/"+rName, data) + v.vaultClient.Logical().Write("auth/approle/role/"+rName, data) sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id") - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create RoleID") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to create role ID for approle") }</span> - <span class="cov1" title="1">v.roleID = sec.Data["role_id"].(string) + <span class="cov5" title="7">v.roleID = sec.Data["role_id"].(string) // Create a secret-id to go with it sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id", map[string]interface{}{}) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create SecretID") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to create secret ID for role") }</span> - <span class="cov1" title="1">v.secretID = sec.Data["secret_id"].(string) + <span class="cov5" title="7">v.secretID = sec.Data["secret_id"].(string) v.initRoleDone = true + /* + * Revoke the Root token. + * If a new Root Token is needed, it will need to be created + * using the unseal shards. + */ + err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken) + if smslogger.CheckError(err, "Revoke Root Token") != nil </span><span class="cov0" title="0">{ + smslogger.WriteWarn("Unable to Revoke Token") + }</span><span class="cov5" title="7"> else { + // Revoked successfully and clear it + v.vaultToken = "" + }</span> + + // Store the role-id and secret-id + // We will need this if SMS restarts + <span class="cov5" title="7">smsauth.WriteToFile(v.roleID, "auth/role") + smsauth.WriteToFile(v.secretID, "auth/secret") + return nil</span> } // Function checkToken() gets called multiple times to create // temporary tokens -func (v *Vault) checkToken() error <span class="cov10" title="24">{ +func (v *Vault) checkToken() error <span class="cov9" title="54">{ + v.Lock() defer v.Unlock() // Init Role if it is not yet done // Role needs to be created before token can be created - if v.initRoleDone == false </span><span class="cov0" title="0">{ - err := v.initRole() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) - return errors.New("Unable to initRole in checkToken") - }</span> - } + err := v.initRole() + if err != nil </span><span class="cov0" title="0">{ + smslogger.WriteError(err.Error()) + return errors.New("Unable to initRole in checkToken") + }</span> // Return immediately if token still has life - <span class="cov10" title="24">if v.vaultClient.Token() != "" && - time.Since(v.vaultTempTokenTTL) < time.Minute*50 </span><span class="cov9" title="23">{ + <span class="cov9" title="54">if v.vaultClient.Token() != "" && + time.Since(v.vaultTempTokenTTL) < time.Minute*50 </span><span class="cov9" title="47">{ return nil }</span> // Create a temporary token using our roleID and secretID - <span class="cov1" title="1">out, err := v.vaultClient.Logical().Write("auth/approle/login", + <span class="cov5" title="7">out, err := v.vaultClient.Logical().Write("auth/approle/login", map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID}) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Temp Token") != nil </span><span class="cov0" title="0">{ return errors.New("Unable to create Temporary Token for Role") }</span> - <span class="cov1" title="1">tok, err := out.TokenID() + <span class="cov5" title="7">tok, err := out.TokenID() v.vaultTempTokenTTL = time.Now() v.vaultClient.SetToken(tok) @@ -655,31 +853,53 @@ func (v *Vault) checkToken() error <span class="cov10" title="24">{ // vaultInit() is used to initialize the vault in cases where it is not // initialized. This happens once during intial bring up. -func (v *Vault) initializeVault() error <span class="cov0" title="0">{ - initReq := &vaultapi.InitRequest{ - SecretShares: 5, +func (v *Vault) initializeVault() error <span class="cov2" title="2">{ + + // Check for vault init status and don't exit till it is initialized + for </span><span class="cov2" title="2">{ + init, err := v.vaultClient.Sys().InitStatus() + if smslogger.CheckError(err, "Get Vault Init Status") != nil </span><span class="cov0" title="0">{ + smslogger.WriteInfo("Trying again in 10s...") + time.Sleep(time.Second * 10) + continue</span> + } + // Did not get any error + <span class="cov2" title="2">if init == true </span><span class="cov1" title="1">{ + smslogger.WriteInfo("Vault is already Initialized") + return nil + }</span> + + // init status is false + // break out of loop and finish initialization + <span class="cov1" title="1">smslogger.WriteInfo("Vault is not initialized. Initializing...") + break</span> + } + + // Hardcoded this to 3. We should make this configurable + // in the future + <span class="cov1" title="1">initReq := &vaultapi.InitRequest{ + SecretShares: 3, SecretThreshold: 3, } pbkey, prkey, err := smsauth.GeneratePGPKeyPair() - if err != nil </span><span class="cov0" title="0">{ + + if smslogger.CheckError(err, "Generating PGP Keys") != nil </span><span class="cov0" title="0">{ smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!") - }</span><span class="cov0" title="0"> else { - initReq.PGPKeys = []string{pbkey, pbkey, pbkey, pbkey, pbkey} + }</span><span class="cov1" title="1"> else { + initReq.PGPKeys = []string{pbkey, pbkey, pbkey} initReq.RootTokenPGPKey = pbkey - v.pgpPub = pbkey - v.pgpPr = prkey }</span> - <span class="cov0" title="0">resp, err := v.vaultClient.Sys().Init(initReq) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + <span class="cov1" title="1">resp, err := v.vaultClient.Sys().Init(initReq) + if smslogger.CheckError(err, "Initialize Vault") != nil </span><span class="cov0" title="0">{ return errors.New("FATAL: Unable to initialize Vault") }</span> - <span class="cov0" title="0">if resp != nil </span><span class="cov0" title="0">{ - v.unsealShards = resp.KeysB64 - v.rootToken = resp.RootToken + <span class="cov1" title="1">if resp != nil </span><span class="cov1" title="1">{ + v.prkey = prkey + v.shards = resp.KeysB64 + v.vaultToken, _ = smsauth.DecryptPGPString(resp.RootToken, prkey) return nil }</span> @@ -708,6 +928,7 @@ package config import ( "encoding/json" "os" + smslogger "sms/log" ) // SMSConfiguration loads up all the values that are used to configure @@ -718,8 +939,10 @@ type SMSConfiguration struct { ServerCert string `json:"servercert"` ServerKey string `json:"serverkey"` - VaultAddress string `json:"vaultaddress"` - VaultToken string `json:"vaulttoken"` + BackendAddress string `json:"smsdbaddress"` + VaultToken string `json:"vaulttoken"` + DisableTLS bool `json:"disable_tls"` + BackendAddressEnvVariable string `json:"smsdburlenv"` } // SMSConfig is the structure that stores the configuration @@ -734,12 +957,19 @@ func ReadConfigFile(file string) (*SMSConfiguration, error) <span class="cov10" }</span> <span class="cov6" title="2">defer f.Close() - SMSConfig = &SMSConfiguration{} + // Default behaviour is to enable TLS + SMSConfig = &SMSConfiguration{DisableTLS: false} decoder := json.NewDecoder(f) err = decoder.Decode(SMSConfig) if err != nil </span><span class="cov0" title="0">{ return nil, err }</span> + + <span class="cov6" title="2">if SMSConfig.BackendAddress == "" && SMSConfig.BackendAddressEnvVariable != "" </span><span class="cov0" title="0">{ + // Get the value from ENV variable + smslogger.WriteInfo("Using Environment Variable: " + SMSConfig.BackendAddressEnvVariable) + SMSConfig.BackendAddress = os.Getenv(SMSConfig.BackendAddressEnvVariable) + }</span> } <span class="cov6" title="2">return SMSConfig, nil</span> @@ -785,29 +1015,24 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques var d smsbackend.SecretDomain err := json.NewDecoder(r.Body).Decode(&d) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusBadRequest) return }</span> <span class="cov6" title="3">dom, err := h.secretBackend.CreateSecretDomain(d.Name) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - <span class="cov6" title="3">jdata, err := json.Marshal(dom) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + err = json.NewEncoder(w).Encode(dom) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - - <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json") - w.WriteHeader(http.StatusCreated) - w.Write(jdata)</span> } // deleteSecretDomainHandler deletes a secret domain with the name provided @@ -816,8 +1041,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques domName := vars["domName"] err := h.secretBackend.DeleteSecretDomain(domName) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> @@ -834,15 +1058,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) <sp // Get secrets to be stored from body var b smsbackend.Secret err := json.NewDecoder(r.Body).Decode(&b) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusBadRequest) return }</span> <span class="cov10" title="7">err = h.secretBackend.CreateSecret(domName, b) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> @@ -857,21 +1079,17 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) <span secName := vars["secretName"] sec, err := h.secretBackend.GetSecret(domName, secName) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - <span class="cov10" title="7">jdata, err := json.Marshal(sec) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(sec) + if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - - <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json") - w.Write(jdata)</span> } // listSecretHandler handles listing all secrets under a particular domain name @@ -880,8 +1098,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span domName := vars["domName"] secList, err := h.secretBackend.ListSecret(domName) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> @@ -893,15 +1110,12 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span secList, } - jdata, err := json.Marshal(retStruct) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(retStruct) + if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - - <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json") - w.Write(jdata)</span> } // deleteSecretHandler handles deleting a secret by given domain name and secret name @@ -911,37 +1125,34 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) <sp secName := vars["secretName"] err := h.secretBackend.DeleteSecret(domName, secName) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "DeleteSecretHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> -} -// struct that tracks various status items for SMS and backend -type backendStatus struct { - Seal bool `json:"sealstatus"` + <span class="cov10" title="7">w.WriteHeader(http.StatusNoContent)</span> } // statusHandler returns information related to SMS and SMS backend services -func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov6" title="3">{ +func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov7" title="4">{ s, err := h.secretBackend.GetStatus() - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - <span class="cov6" title="3">status := backendStatus{Seal: s} - jdata, err := json.Marshal(status) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + <span class="cov7" title="4">status := struct { + Seal bool `json:"sealstatus"` + }{ + s, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(status) + if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> - - <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json") - w.Write(jdata)</span> } // loginHandler handles login via password and username @@ -961,15 +1172,53 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) <span cla decoder := json.NewDecoder(r.Body) decoder.DisallowUnknownFields() err := decoder.Decode(&inp) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, "Bad input JSON", http.StatusBadRequest) return }</span> <span class="cov0" title="0">err = h.secretBackend.Unseal(inp.UnsealShard) - if err != nil </span><span class="cov0" title="0">{ - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{ + http.Error(w, err.Error(), http.StatusInternalServerError) + return + }</span> +} + +// registerHandler allows the quorum clients to register with SMS +// with their PGP public keys that are then used by sms for backend +// initialization +func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) <span class="cov1" title="1">{ + // Get shards to be used for unseal + type registerStruct struct { + PGPKey string `json:"pgpkey"` + QuorumID string `json:"quorumid"` + } + + var inp registerStruct + decoder := json.NewDecoder(r.Body) + decoder.DisallowUnknownFields() + err := decoder.Decode(&inp) + if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{ + http.Error(w, "Bad input JSON", http.StatusBadRequest) + return + }</span> + + <span class="cov1" title="1">sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey) + if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{ + http.Error(w, err.Error(), http.StatusInternalServerError) + return + }</span> + + // Creating a struct for return data + <span class="cov1" title="1">shStruct := struct { + Shard string `json:"shard"` + }{ + sh, + } + + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(shStruct) + if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{ http.Error(w, err.Error(), http.StatusInternalServerError) return }</span> @@ -987,8 +1236,9 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti // Initialization APIs which will be used by quorum client // to unseal and to provide root token to sms service - router.HandleFunc("/v1/sms/status", h.statusHandler).Methods("GET") - router.HandleFunc("/v1/sms/unseal", h.unsealHandler).Methods("POST") + router.HandleFunc("/v1/sms/quorum/status", h.statusHandler).Methods("GET") + router.HandleFunc("/v1/sms/quorum/unseal", h.unsealHandler).Methods("POST") + router.HandleFunc("/v1/sms/quorum/register", h.registerHandler).Methods("POST") router.HandleFunc("/v1/sms/domain", h.createSecretDomainHandler).Methods("POST") router.HandleFunc("/v1/sms/domain/{domName}", h.deleteSecretDomainHandler).Methods("DELETE") @@ -1021,53 +1271,85 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti package log import ( + "fmt" "log" "os" ) -var errLogger *log.Logger -var warnLogger *log.Logger -var infoLogger *log.Logger +var errL, warnL, infoL *log.Logger +var stdErr, stdWarn, stdInfo *log.Logger // Init will be called by sms.go before any other packages use it -func Init(filePath string) <span class="cov8" title="1">{ - f, err := os.Create(filePath) +func Init(filePath string) <span class="cov1" title="1">{ + + stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags) + stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags) + stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags) + + if filePath == "" </span><span class="cov0" title="0">{ + // We will just to std streams + return + }</span> + + <span class="cov1" title="1">f, err := os.Create(filePath) if err != nil </span><span class="cov0" title="0">{ - log.Println("Unable to create a log file") - log.Println(err) - errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags) - warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags) - infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags) - }</span><span class="cov8" title="1"> else { - errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags) - warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags) - infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags) + stdErr.Println("Unable to create log file: " + err.Error()) + return }</span> + + <span class="cov1" title="1">errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags) + warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags) + infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)</span> } // WriteError writes output to the writer we have -// defined durint its creation with ERROR prefix +// defined during its creation with ERROR prefix func WriteError(msg string) <span class="cov0" title="0">{ - if errLogger != nil </span><span class="cov0" title="0">{ - errLogger.Println(msg) + if errL != nil </span><span class="cov0" title="0">{ + errL.Output(2, fmt.Sprintln(msg)) + }</span> + <span class="cov0" title="0">if stdErr != nil </span><span class="cov0" title="0">{ + stdErr.Output(2, fmt.Sprintln(msg)) }</span> } // WriteWarn writes output to the writer we have -// defined durint its creation with WARNING prefix +// defined during its creation with WARNING prefix func WriteWarn(msg string) <span class="cov0" title="0">{ - if warnLogger != nil </span><span class="cov0" title="0">{ - warnLogger.Println(msg) + if warnL != nil </span><span class="cov0" title="0">{ + warnL.Output(2, fmt.Sprintln(msg)) + }</span> + <span class="cov0" title="0">if stdWarn != nil </span><span class="cov0" title="0">{ + stdWarn.Output(2, fmt.Sprintln(msg)) }</span> } // WriteInfo writes output to the writer we have -// defined durint its creation with INFO prefix -func WriteInfo(msg string) <span class="cov0" title="0">{ - if infoLogger != nil </span><span class="cov0" title="0">{ - infoLogger.Println(msg) +// defined during its creation with INFO prefix +func WriteInfo(msg string) <span class="cov1" title="1">{ + if infoL != nil </span><span class="cov1" title="1">{ + infoL.Output(2, fmt.Sprintln(msg)) + }</span> + <span class="cov1" title="1">if stdInfo != nil </span><span class="cov1" title="1">{ + stdInfo.Output(2, fmt.Sprintln(msg)) }</span> } + +//CheckError is a helper function to reduce +//repitition of error checkign blocks of code +func CheckError(err error, topic string) error <span class="cov10" title="116">{ + if err != nil </span><span class="cov1" title="1">{ + msg := topic + ": " + err.Error() + if errL != nil </span><span class="cov1" title="1">{ + errL.Output(2, fmt.Sprintln(msg)) + }</span> + <span class="cov1" title="1">if stdErr != nil </span><span class="cov1" title="1">{ + stdErr.Output(2, fmt.Sprintln(msg)) + }</span> + <span class="cov1" title="1">return err</span> + } + <span class="cov9" title="115">return nil</span> +} </pre> <pre class="file" id="file6" style="display: none">/* @@ -1119,16 +1401,9 @@ func main() <span class="cov8" title="1">{ <span class="cov8" title="1">httpRouter := smshandler.CreateRouter(backendImpl) - // TODO: Use CA certificate from AAF - tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile) - if err != nil </span><span class="cov0" title="0">{ - log.Fatal(err) - }</span> - - <span class="cov8" title="1">httpServer := &http.Server{ - Handler: httpRouter, - Addr: ":10443", - TLSConfig: tlsConfig, + httpServer := &http.Server{ + Handler: httpRouter, + Addr: ":10443", } // Listener for SIGINT so that it returns cleanly @@ -1141,8 +1416,22 @@ func main() <span class="cov8" title="1">{ close(connectionsClose) }</span>() - <span class="cov8" title="1">err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey) - if err != nil && err != http.ErrServerClosed </span><span class="cov0" title="0">{ + // Start in TLS mode by default + <span class="cov8" title="1">if smsConf.DisableTLS == true </span><span class="cov0" title="0">{ + smslogger.WriteWarn("TLS is Disabled") + err = httpServer.ListenAndServe() + }</span><span class="cov8" title="1"> else { + // TODO: Use CA certificate from AAF + tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile) + if err != nil </span><span class="cov0" title="0">{ + log.Fatal(err) + }</span> + + <span class="cov8" title="1">httpServer.TLSConfig = tlsConfig + err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey)</span> + } + + <span class="cov8" title="1">if err != nil && err != http.ErrServerClosed </span><span class="cov0" title="0">{ log.Fatal(err) }</span> diff --git a/sms-service/src/sms/coverage.md b/sms-service/doc/coverage.md index 6168342..6168342 100644 --- a/sms-service/src/sms/coverage.md +++ b/sms-service/doc/coverage.md diff --git a/sms-service/src/quorumclient/quorumclient.go b/sms-service/src/quorumclient/quorumclient.go index 05fc967..dfa1a26 100644 --- a/sms-service/src/quorumclient/quorumclient.go +++ b/sms-service/src/quorumclient/quorumclient.go @@ -37,13 +37,13 @@ func loadPGPKeys(prKeyPath string, pbKeyPath string) (string, string, error) { var pbkey, prkey string generated := false prkey, err := smsauth.ReadFromFile(prKeyPath) - if err != nil { - smslogger.WriteWarn("No Private Key found. Generating...") + if smslogger.CheckError(err, "LoadPGP Private Key") != nil { + smslogger.WriteInfo("No Private Key found. Generating...") pbkey, prkey, _ = smsauth.GeneratePGPKeyPair() generated = true } else { pbkey, err = smsauth.ReadFromFile(pbKeyPath) - if err != nil { + if smslogger.CheckError(err, "LoadPGP Public Key") != nil { smslogger.WriteWarn("No Public Key found. Generating...") pbkey, prkey, _ = smsauth.GeneratePGPKeyPair() generated = true @@ -70,7 +70,7 @@ func main() { prKeyPath := filepath.Join("auth", podName, "prkey") shardPath := filepath.Join("auth", podName, "shard") - smslogger.Init("") + smslogger.Init("quorum.log") smslogger.WriteInfo("Starting Log for Quorum Client") /* @@ -80,7 +80,7 @@ func main() { In Kubernetes, pod restarts will also change the hostname */ myID, err := smsauth.ReadFromFile(idFilePath) - if err != nil { + if smslogger.CheckError(err, "Read ID") != nil { smslogger.WriteWarn("Unable to find an ID for this client. Generating...") myID, _ = uuid.GenerateUUID() smsauth.WriteToFile(myID, idFilePath) @@ -93,7 +93,7 @@ func main() { */ registrationDone := true myShard, err := smsauth.ReadFromFile(shardPath) - if err != nil { + if smslogger.CheckError(err, "Read Shard") != nil { smslogger.WriteWarn("Unable to find a shard file. Registering with SMS...") registrationDone = false } @@ -160,8 +160,7 @@ func main() { //URL and Port is configured in config file response, err := client.Get(cfg.BackEndURL + "/v1/sms/quorum/status") - if err != nil { - smslogger.WriteError("Unable to connect to SMS. Retrying...") + if smslogger.CheckError(err, "Connect to SMS") != nil { continue } @@ -178,8 +177,7 @@ func main() { if !registrationDone { body := strings.NewReader(`{"pgpkey":"` + pbkey + `","quorumid":"` + myID + `"}`) res, err := client.Post(cfg.BackEndURL+"/v1/sms/quorum/register", "application/json", body) - if err != nil { - smslogger.WriteError("Ran into error during registration. Retrying...") + if smslogger.CheckError(err, "Register with SMS") != nil { continue } registrationDone = true @@ -195,8 +193,8 @@ func main() { body := strings.NewReader(`{"unsealshard":"` + decShard + `"}`) //URL and PORT is configured via config file response, err = client.Post(cfg.BackEndURL+"/v1/sms/quorum/unseal", "application/json", body) - if err != nil { - smslogger.WriteError("Error unsealing vault. Retrying... " + err.Error()) + if smslogger.CheckError(err, "Unsealing Vault") != nil { + continue } } } diff --git a/sms-service/src/sms/Gopkg.lock b/sms-service/src/sms/Gopkg.lock index c7684c7..2c09256 100644 --- a/sms-service/src/sms/Gopkg.lock +++ b/sms-service/src/sms/Gopkg.lock @@ -477,6 +477,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "d19e17a023506ab731b0f26c6fcfebe581d4d5194af094aecea5e72daddd3ead" + inputs-digest = "8280cde72a3ab78ad00d13c192de5920d188f3052f45884563896cab659469f9" solver-name = "gps-cdcl" solver-version = 1 diff --git a/sms-service/src/sms/auth/auth.go b/sms-service/src/sms/auth/auth.go index 7172505..038e31d 100644 --- a/sms-service/src/sms/auth/auth.go +++ b/sms-service/src/sms/auth/auth.go @@ -29,39 +29,27 @@ import ( smslogger "sms/log" ) -var tlsConfig *tls.Config - -func checkError(err error, topic string) error { - if err != nil { - smslogger.WriteError(topic + ": " + err.Error()) - return err - } - - return nil -} - // GetTLSConfig initializes a tlsConfig using the CA's certificate // This config is then used to enable the server for mutual TLS func GetTLSConfig(caCertFile string) (*tls.Config, error) { + // Initialize tlsConfig once - if tlsConfig == nil { - caCert, err := ioutil.ReadFile(caCertFile) + caCert, err := ioutil.ReadFile(caCertFile) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) - tlsConfig = &tls.Config{ - // Change to RequireAndVerify once we have mandatory certs - ClientAuth: tls.VerifyClientCertIfGiven, - ClientCAs: caCertPool, - MinVersion: tls.VersionTLS12, - } - tlsConfig.BuildNameToCertificate() + tlsConfig := &tls.Config{ + // Change to RequireAndVerify once we have mandatory certs + ClientAuth: tls.VerifyClientCertIfGiven, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, } + tlsConfig.BuildNameToCertificate() return tlsConfig, nil } @@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) { // A base64 encoded form of the public part of the entity // A base64 encoded form of the private key func GeneratePGPKeyPair() (string, string, error) { + var entity *openpgp.Entity config := &packet.Config{ DefaultHash: crypto.SHA256, } entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Entity") != nil { return "", "", err } // Sign the identity in the entity for _, id := range entity.Identities { err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Entity") != nil { return "", "", err } } @@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) { // Sign the subkey in the entity for _, subkey := range entity.Subkeys { err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Subkey") != nil { return "", "", err } } @@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) { // EncryptPGPString takes data and a public key and encrypts using that // public key func EncryptPGPString(data string, pbKey string) (string, error) { + pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey) - if checkError(err, "Decoding Base64 Public Key") != nil { + if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil { return "", err } dataBytes := []byte(data) pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes))) - if checkError(err, "Reading entity from PGP key") != nil { + if smslogger.CheckError(err, "Reading entity from PGP key") != nil { return "", err } // encrypt string buf := new(bytes.Buffer) out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil) - if checkError(err, "Creating Encryption Pipe") != nil { + if smslogger.CheckError(err, "Creating Encryption Pipe") != nil { return "", err } _, err = out.Write(dataBytes) - if checkError(err, "Writing to Encryption Pipe") != nil { + if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil { return "", err } err = out.Close() - if checkError(err, "Closing Encryption Pipe") != nil { + if smslogger.CheckError(err, "Closing Encryption Pipe") != nil { return "", err } @@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) { // DecryptPGPString decrypts a PGP encoded input string and returns // a base64 representation of the decoded string func DecryptPGPString(data string, prKey string) (string, error) { + // Convert private key to bytes from base64 prKeyBytes, err := base64.StdEncoding.DecodeString(prKey) - if err != nil { - smslogger.WriteError("Error Decoding base64 private key: " + err.Error()) + if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil { return "", err } dataBytes, err := base64.StdEncoding.DecodeString(data) - if err != nil { - smslogger.WriteError("Error Decoding base64 data: " + err.Error()) + if smslogger.CheckError(err, "Decoding base64 data") != nil { return "", err } prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes))) - if err != nil { - smslogger.WriteError("Error reading entity from PGP key: " + err.Error()) + if smslogger.CheckError(err, "Read Entity") != nil { return "", err } prEntityList := &openpgp.EntityList{prEntity} message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil) - if err != nil { - smslogger.WriteError("Error Decrypting message: " + err.Error()) + if smslogger.CheckError(err, "Decrypting Message") != nil { return "", err } @@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) { func ReadFromFile(fileName string) (string, error) { data, err := ioutil.ReadFile(fileName) - if err != nil { - smslogger.WriteError(err.Error()) - smslogger.WriteError("Cannot read file: " + fileName) + if smslogger.CheckError(err, "Read from file") != nil { return "", err } return string(data), nil - } // WriteToFile writes a PGP key into a file. @@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) { func WriteToFile(data string, fileName string) error { err := ioutil.WriteFile(fileName, []byte(data), 0600) - if err != nil { - smslogger.WriteError(err.Error()) - smslogger.WriteError("Cannot write to file: " + fileName) + if smslogger.CheckError(err, "Write to file") != nil { return err } return nil - } diff --git a/sms-service/src/sms/backend/backend.go b/sms-service/src/sms/backend/backend.go index c137636..d7662ef 100644 --- a/sms-service/src/sms/backend/backend.go +++ b/sms-service/src/sms/backend/backend.go @@ -60,8 +60,7 @@ func InitSecretBackend() (SecretBackend, error) { } err := backendImpl.Init() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "InitSecretBackend") != nil { return nil, err } diff --git a/sms-service/src/sms/backend/vault.go b/sms-service/src/sms/backend/vault.go index e26baff..7fee097 100644 --- a/sms-service/src/sms/backend/vault.go +++ b/sms-service/src/sms/backend/vault.go @@ -56,9 +56,8 @@ func (v *Vault) initVaultClient() error { vaultCFG := vaultapi.DefaultConfig() vaultCFG.Address = v.vaultAddress client, err := vaultapi.NewClient(vaultCFG) - if err != nil { - smslogger.WriteError(err.Error()) - return errors.New("Unable to create new vault client") + if smslogger.CheckError(err, "Create new vault client") != nil { + return err } v.initRoleDone = false @@ -69,7 +68,6 @@ func (v *Vault) initVaultClient() error { v.internalDomainMounted = false v.prkey = "" return nil - } // Init will initialize the vault connection @@ -84,8 +82,7 @@ func (v *Vault) Init() error { v.initializeVault() err := v.initRole() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "InitRole First Attempt") != nil { smslogger.WriteInfo("InitRole will try again later") } @@ -94,10 +91,10 @@ func (v *Vault) Init() error { // GetStatus returns the current seal status of vault func (v *Vault) GetStatus() (bool, error) { + sys := v.vaultClient.Sys() sealStatus, err := sys.SealStatus() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Getting Status") != nil { return false, errors.New("Error getting status") } @@ -112,7 +109,7 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) { defer v.Unlock() if v.shards == nil { - smslogger.WriteError("Invalid operation") + smslogger.WriteError("Invalid operation in RegisterQuorum") return "", errors.New("Invalid operation") } // Pop the slice @@ -133,10 +130,10 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) { // Unseal is a passthrough API that allows any // unseal or initialization processes for the backend func (v *Vault) Unseal(shard string) error { + sys := v.vaultClient.Sys() _, err := sys.Unseal(shard) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Unseal Operation") != nil { return errors.New("Unable to execute unseal operation with specified shard") } @@ -147,17 +144,16 @@ func (v *Vault) Unseal(shard string) error { // The secret itself is referenced via its name which translates to // a mount path in vault func (v *Vault) GetSecret(dom string, name string) (Secret, error) { + err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Tocken Check") != nil { return Secret{}, errors.New("Token check failed") } dom = v.vaultMountPrefix + "/" + dom sec, err := v.vaultClient.Logical().Read(dom + "/" + name) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Read Secret") != nil { return Secret{}, errors.New("Unable to read Secret at provided path") } @@ -173,17 +169,16 @@ func (v *Vault) GetSecret(dom string, name string) (Secret, error) { // ListSecret returns a list of secret names on a particular domain // The values of the secret are not returned func (v *Vault) ListSecret(dom string) ([]string, error) { + err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return nil, errors.New("Token check failed") } dom = v.vaultMountPrefix + "/" + dom sec, err := v.vaultClient.Logical().List(dom) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Read Secret") != nil { return nil, errors.New("Unable to read Secret at provided path") } @@ -209,6 +204,7 @@ func (v *Vault) ListSecret(dom string) ([]string, error) { // Mounts the internal Domain if its not already mounted func (v *Vault) mountInternalDomain(name string) error { + if v.internalDomainMounted { return nil } @@ -224,14 +220,13 @@ func (v *Vault) mountInternalDomain(name string) error { } err := v.vaultClient.Sys().Mount(mountPath, mountInput) - if err != nil { + if smslogger.CheckError(err, "Mount internal Domain") != nil { if strings.Contains(err.Error(), "existing mount") { // It is already mounted v.internalDomainMounted = true return nil } // Ran into some other error mounting it. - smslogger.WriteError(err.Error()) return errors.New("Unable to mount internal Domain") } @@ -242,16 +237,15 @@ func (v *Vault) mountInternalDomain(name string) error { // Stores the UUID created for secretdomain in vault // under v.vaultMountPrefix / smsinternal domain func (v *Vault) storeUUID(uuid string, name string) error { + // Check if token is still valid err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return errors.New("Token Check failed") } err = v.mountInternalDomain(v.internalDomain) - if err != nil { - smslogger.WriteError("Could not mount internal domain") + if smslogger.CheckError(err, "Mount Internal Domain") != nil { return err } @@ -263,8 +257,7 @@ func (v *Vault) storeUUID(uuid string, name string) error { } err = v.CreateSecret(v.internalDomain, secret) - if err != nil { - smslogger.WriteError("Unable to write UUID to internal domain") + if smslogger.CheckError(err, "Write UUID to domain") != nil { return err } @@ -273,10 +266,10 @@ func (v *Vault) storeUUID(uuid string, name string) error { // CreateSecretDomain mounts the kv backend on a path with the given name func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) { + // Check if token is still valid err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return SecretDomain{}, errors.New("Token Check failed") } @@ -291,14 +284,13 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) { } err = v.vaultClient.Sys().Mount(mountPath, mountInput) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Domain") != nil { return SecretDomain{}, errors.New("Unable to create Secret Domain") } uuid, _ := uuid.GenerateUUID() err = v.storeUUID(uuid, name) - if err != nil { + if smslogger.CheckError(err, "Store UUID") != nil { // Mount was successful at this point. // Rollback the mount operation since we could not // store the UUID for the mount. @@ -312,9 +304,9 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) { // CreateSecret creates a secret mounted on a particular domain name // The secret itself is mounted on a path specified by name func (v *Vault) CreateSecret(dom string, sec Secret) error { + err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return errors.New("Token check failed") } @@ -323,8 +315,7 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error { // Vault return is empty on successful write // TODO: Check if values is not empty _, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Secret") != nil { return errors.New("Unable to create Secret at provided path") } @@ -334,9 +325,9 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error { // DeleteSecretDomain deletes a secret domain which translates to // an unmount operation on the given path in Vault func (v *Vault) DeleteSecretDomain(name string) error { + err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return errors.New("Token Check Failed") } @@ -344,8 +335,7 @@ func (v *Vault) DeleteSecretDomain(name string) error { mountPath := v.vaultMountPrefix + "/" + name err = v.vaultClient.Sys().Unmount(mountPath) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Delete Domain") != nil { return errors.New("Unable to delete domain specified") } @@ -356,8 +346,7 @@ func (v *Vault) DeleteSecretDomain(name string) error { func (v *Vault) DeleteSecret(dom string, name string) error { err := v.checkToken() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Token Check") != nil { return errors.New("Token check failed") } @@ -365,8 +354,7 @@ func (v *Vault) DeleteSecret(dom string, name string) error { // Vault return is empty on successful delete _, err = v.vaultClient.Logical().Delete(dom + "/" + name) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Delete Secret") != nil { return errors.New("Unable to delete Secret at provided path") } @@ -406,15 +394,13 @@ func (v *Vault) initRole() error { rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] } path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }` err := v.vaultClient.Sys().PutPolicy(v.policyName, rules) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Creating Policy") != nil { return errors.New("Unable to create policy for approle creation") } //Check if applrole is mounted authMounts, err := v.vaultClient.Sys().ListAuth() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Mount Auth Backend") != nil { return errors.New("Unable to get mounted auth backends") } @@ -440,8 +426,7 @@ func (v *Vault) initRole() error { // Create a role-id v.vaultClient.Logical().Write("auth/approle/role/"+rName, data) sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id") - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create RoleID") != nil { return errors.New("Unable to create role ID for approle") } v.roleID = sec.Data["role_id"].(string) @@ -449,8 +434,7 @@ func (v *Vault) initRole() error { // Create a secret-id to go with it sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id", map[string]interface{}{}) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create SecretID") != nil { return errors.New("Unable to create secret ID for role") } @@ -462,8 +446,7 @@ func (v *Vault) initRole() error { * using the unseal shards. */ err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken) - if err != nil { - smslogger.WriteWarn(err.Error()) + if smslogger.CheckError(err, "Revoke Root Token") != nil { smslogger.WriteWarn("Unable to Revoke Token") } else { // Revoked successfully and clear it @@ -481,6 +464,7 @@ func (v *Vault) initRole() error { // Function checkToken() gets called multiple times to create // temporary tokens func (v *Vault) checkToken() error { + v.Lock() defer v.Unlock() @@ -501,8 +485,7 @@ func (v *Vault) checkToken() error { // Create a temporary token using our roleID and secretID out, err := v.vaultClient.Logical().Write("auth/approle/login", map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID}) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Temp Token") != nil { return errors.New("Unable to create Temporary Token for Role") } @@ -516,11 +499,12 @@ func (v *Vault) checkToken() error { // vaultInit() is used to initialize the vault in cases where it is not // initialized. This happens once during intial bring up. func (v *Vault) initializeVault() error { + // Check for vault init status and don't exit till it is initialized for { init, err := v.vaultClient.Sys().InitStatus() - if err != nil { - smslogger.WriteError("Unable to get initStatus, trying again in 10s: " + err.Error()) + if smslogger.CheckError(err, "Get Vault Init Status") != nil { + smslogger.WriteInfo("Trying again in 10s...") time.Sleep(time.Second * 10) continue } @@ -545,7 +529,7 @@ func (v *Vault) initializeVault() error { pbkey, prkey, err := smsauth.GeneratePGPKeyPair() - if err != nil { + if smslogger.CheckError(err, "Generating PGP Keys") != nil { smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!") } else { initReq.PGPKeys = []string{pbkey, pbkey, pbkey} @@ -553,8 +537,7 @@ func (v *Vault) initializeVault() error { } resp, err := v.vaultClient.Sys().Init(initReq) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Initialize Vault") != nil { return errors.New("FATAL: Unable to initialize Vault") } diff --git a/sms-service/src/sms/backend/vault_test.go b/sms-service/src/sms/backend/vault_test.go index 484c395..4862665 100644 --- a/sms-service/src/sms/backend/vault_test.go +++ b/sms-service/src/sms/backend/vault_test.go @@ -17,9 +17,11 @@ package backend import ( + vaultapi "github.com/hashicorp/vault/api" credAppRole "github.com/hashicorp/vault/builtin/credential/approle" vaulthttp "github.com/hashicorp/vault/http" vaultlogical "github.com/hashicorp/vault/logical" + vaultinmem "github.com/hashicorp/vault/physical/inmem" vaulttesting "github.com/hashicorp/vault/vault" "reflect" smslog "sms/log" @@ -229,3 +231,39 @@ func TestDeleteSecret(t *testing.T) { t.Fatal("DeleteSecret: Error Creating secret") } } + +func TestInitializeVault(t *testing.T) { + + inm, err := vaultinmem.NewInmem(nil, nil) + if err != nil { + t.Fatal(err) + } + + core, err := vaulttesting.NewCore(&vaulttesting.CoreConfig{ + DisableMlock: true, + DisableCache: true, + Physical: inm, + }) + if err != nil { + t.Fatal(err) + } + + ln, addr := vaulthttp.TestServer(t, core) + defer ln.Close() + + client, err := vaultapi.NewClient(&vaultapi.Config{ + Address: addr, + }) + if err != nil { + t.Fatal(err) + } + + v := &Vault{} + v.initVaultClient() + v.vaultClient = client + + err = v.initializeVault() + if err != nil { + t.Fatal("InitializeVault: Error initializing Vault") + } +} diff --git a/sms-service/src/sms/handler/handler.go b/sms-service/src/sms/handler/handler.go index dbf3f93..7ce9e01 100644 --- a/sms-service/src/sms/handler/handler.go +++ b/sms-service/src/sms/handler/handler.go @@ -37,15 +37,13 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques var d smsbackend.SecretDomain err := json.NewDecoder(r.Body).Decode(&d) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } dom, err := h.secretBackend.CreateSecretDomain(d.Name) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -53,8 +51,7 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) err = json.NewEncoder(w).Encode(dom) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -66,8 +63,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques domName := vars["domName"] err := h.secretBackend.DeleteSecretDomain(domName) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -84,15 +80,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) { // Get secrets to be stored from body var b smsbackend.Secret err := json.NewDecoder(r.Body).Decode(&b) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretHandler") != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } err = h.secretBackend.CreateSecret(domName, b) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "CreateSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -107,16 +101,14 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) { secName := vars["secretName"] sec, err := h.secretBackend.GetSecret(domName, secName) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "GetSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(sec) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "GetSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -128,8 +120,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) { domName := vars["domName"] secList, err := h.secretBackend.ListSecret(domName) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "ListSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -143,8 +134,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(retStruct) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "ListSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -157,8 +147,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) { secName := vars["secretName"] err := h.secretBackend.DeleteSecret(domName, secName) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "DeleteSecretHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -169,8 +158,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) { // statusHandler returns information related to SMS and SMS backend services func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) { s, err := h.secretBackend.GetStatus() - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "StatusHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -183,8 +171,7 @@ func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(status) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "StatusHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -207,15 +194,13 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) decoder.DisallowUnknownFields() err := decoder.Decode(&inp) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "UnsealHandler") != nil { http.Error(w, "Bad input JSON", http.StatusBadRequest) return } err = h.secretBackend.Unseal(inp.UnsealShard) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "UnsealHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -235,15 +220,13 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) { decoder := json.NewDecoder(r.Body) decoder.DisallowUnknownFields() err := decoder.Decode(&inp) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "RegisterHandler") != nil { http.Error(w, "Bad input JSON", http.StatusBadRequest) return } sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "RegisterHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } @@ -257,8 +240,7 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") err = json.NewEncoder(w).Encode(shStruct) - if err != nil { - smslogger.WriteError("Unable to encode response: " + err.Error()) + if smslogger.CheckError(err, "RegisterHandler") != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } diff --git a/sms-service/src/sms/handler/handler_test.go b/sms-service/src/sms/handler/handler_test.go index 52637f3..c1e55ed 100644 --- a/sms-service/src/sms/handler/handler_test.go +++ b/sms-service/src/sms/handler/handler_test.go @@ -48,7 +48,7 @@ func (b *TestBackend) Unseal(shard string) error { } func (b *TestBackend) RegisterQuorum(pgpkey string) (string, error) { - return "", nil + return "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", nil } func (b *TestBackend) GetSecret(dom string, sec string) (smsbackend.Secret, error) { @@ -127,8 +127,49 @@ func TestStatusHandler(t *testing.T) { } } +func TestRegisterHandler(t *testing.T) { + body := `{ + "pgpkey":"asdasdasdasdgkjgljoiwera", + "quorumid":"123e4567-e89b-12d3-a456-426655440000" + }` + reader := strings.NewReader(body) + req, err := http.NewRequest("POST", "/v1/sms/quorum/register", reader) + if err != nil { + t.Fatal(err) + } + + rr := httptest.NewRecorder() + hr := http.HandlerFunc(h.registerHandler) + + hr.ServeHTTP(rr, req) + + ret := rr.Code + if ret != http.StatusOK { + t.Errorf("registerHandler returned wrong status code: %v vs %v", + ret, http.StatusOK) + } + + expected := struct { + Shard string `json:"shard"` + }{ + "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", + } + got := struct { + Shard string `json:"shard"` + }{} + + json.NewDecoder(rr.Body).Decode(&got) + + if reflect.DeepEqual(expected, got) == false { + t.Errorf("statusHandler returned unexpected body: got %v vs %v", + rr.Body.String(), expected) + } +} + func TestUnsealHandler(t *testing.T) { - req, err := http.NewRequest("GET", "/v1/sms/quorum/unseal", nil) + body := `{"unsealshard":"N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU="}` + reader := strings.NewReader(body) + req, err := http.NewRequest("POST", "/v1/sms/quorum/unseal", reader) if err != nil { t.Fatal(err) } diff --git a/sms-service/src/sms/log/logger.go b/sms-service/src/sms/log/logger.go index 25da593..660f1ce 100644 --- a/sms-service/src/sms/log/logger.go +++ b/sms-service/src/sms/log/logger.go @@ -17,57 +17,82 @@ package log import ( + "fmt" "log" "os" ) -var errLogger *log.Logger -var warnLogger *log.Logger -var infoLogger *log.Logger +var errL, warnL, infoL *log.Logger +var stdErr, stdWarn, stdInfo *log.Logger // Init will be called by sms.go before any other packages use it func Init(filePath string) { + + stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags) + stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags) + stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags) + if filePath == "" { - errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags) - warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags) - infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags) + // We will just to std streams return } f, err := os.Create(filePath) if err != nil { - log.Println("Unable to create a log file") - log.Println(err) - errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags) - warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags) - infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags) - } else { - errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags) - warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags) - infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags) + stdErr.Println("Unable to create log file: " + err.Error()) + return } + + errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags) + warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags) + infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags) } // WriteError writes output to the writer we have -// defined durint its creation with ERROR prefix +// defined during its creation with ERROR prefix func WriteError(msg string) { - if errLogger != nil { - errLogger.Println(msg) + if errL != nil { + errL.Output(2, fmt.Sprintln(msg)) + } + if stdErr != nil { + stdErr.Output(2, fmt.Sprintln(msg)) } } // WriteWarn writes output to the writer we have -// defined durint its creation with WARNING prefix +// defined during its creation with WARNING prefix func WriteWarn(msg string) { - if warnLogger != nil { - warnLogger.Println(msg) + if warnL != nil { + warnL.Output(2, fmt.Sprintln(msg)) + } + if stdWarn != nil { + stdWarn.Output(2, fmt.Sprintln(msg)) } } // WriteInfo writes output to the writer we have -// defined durint its creation with INFO prefix +// defined during its creation with INFO prefix func WriteInfo(msg string) { - if infoLogger != nil { - infoLogger.Println(msg) + if infoL != nil { + infoL.Output(2, fmt.Sprintln(msg)) + } + if stdInfo != nil { + stdInfo.Output(2, fmt.Sprintln(msg)) + } +} + +//CheckError is a helper function to reduce +//repitition of error checkign blocks of code +func CheckError(err error, topic string) error { + if err != nil { + msg := topic + ": " + err.Error() + if errL != nil { + errL.Output(2, fmt.Sprintln(msg)) + } + if stdErr != nil { + stdErr.Output(2, fmt.Sprintln(msg)) + } + return err } + return nil } |