summaryrefslogtreecommitdiffstats
path: root/sms-service/src
diff options
context:
space:
mode:
Diffstat (limited to 'sms-service/src')
-rw-r--r--sms-service/src/quorumclient/quorumclient.go22
-rw-r--r--sms-service/src/sms/Gopkg.lock2
-rw-r--r--sms-service/src/sms/auth/auth.go82
-rw-r--r--sms-service/src/sms/backend/backend.go3
-rw-r--r--sms-service/src/sms/backend/vault.go105
-rw-r--r--sms-service/src/sms/backend/vault_test.go38
-rw-r--r--sms-service/src/sms/coverage.md41
-rw-r--r--sms-service/src/sms/handler/handler.go54
-rw-r--r--sms-service/src/sms/handler/handler_test.go45
-rw-r--r--sms-service/src/sms/log/logger.go73
10 files changed, 234 insertions, 231 deletions
diff --git a/sms-service/src/quorumclient/quorumclient.go b/sms-service/src/quorumclient/quorumclient.go
index 05fc967..dfa1a26 100644
--- a/sms-service/src/quorumclient/quorumclient.go
+++ b/sms-service/src/quorumclient/quorumclient.go
@@ -37,13 +37,13 @@ func loadPGPKeys(prKeyPath string, pbKeyPath string) (string, string, error) {
var pbkey, prkey string
generated := false
prkey, err := smsauth.ReadFromFile(prKeyPath)
- if err != nil {
- smslogger.WriteWarn("No Private Key found. Generating...")
+ if smslogger.CheckError(err, "LoadPGP Private Key") != nil {
+ smslogger.WriteInfo("No Private Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
} else {
pbkey, err = smsauth.ReadFromFile(pbKeyPath)
- if err != nil {
+ if smslogger.CheckError(err, "LoadPGP Public Key") != nil {
smslogger.WriteWarn("No Public Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
@@ -70,7 +70,7 @@ func main() {
prKeyPath := filepath.Join("auth", podName, "prkey")
shardPath := filepath.Join("auth", podName, "shard")
- smslogger.Init("")
+ smslogger.Init("quorum.log")
smslogger.WriteInfo("Starting Log for Quorum Client")
/*
@@ -80,7 +80,7 @@ func main() {
In Kubernetes, pod restarts will also change the hostname
*/
myID, err := smsauth.ReadFromFile(idFilePath)
- if err != nil {
+ if smslogger.CheckError(err, "Read ID") != nil {
smslogger.WriteWarn("Unable to find an ID for this client. Generating...")
myID, _ = uuid.GenerateUUID()
smsauth.WriteToFile(myID, idFilePath)
@@ -93,7 +93,7 @@ func main() {
*/
registrationDone := true
myShard, err := smsauth.ReadFromFile(shardPath)
- if err != nil {
+ if smslogger.CheckError(err, "Read Shard") != nil {
smslogger.WriteWarn("Unable to find a shard file. Registering with SMS...")
registrationDone = false
}
@@ -160,8 +160,7 @@ func main() {
//URL and Port is configured in config file
response, err := client.Get(cfg.BackEndURL + "/v1/sms/quorum/status")
- if err != nil {
- smslogger.WriteError("Unable to connect to SMS. Retrying...")
+ if smslogger.CheckError(err, "Connect to SMS") != nil {
continue
}
@@ -178,8 +177,7 @@ func main() {
if !registrationDone {
body := strings.NewReader(`{"pgpkey":"` + pbkey + `","quorumid":"` + myID + `"}`)
res, err := client.Post(cfg.BackEndURL+"/v1/sms/quorum/register", "application/json", body)
- if err != nil {
- smslogger.WriteError("Ran into error during registration. Retrying...")
+ if smslogger.CheckError(err, "Register with SMS") != nil {
continue
}
registrationDone = true
@@ -195,8 +193,8 @@ func main() {
body := strings.NewReader(`{"unsealshard":"` + decShard + `"}`)
//URL and PORT is configured via config file
response, err = client.Post(cfg.BackEndURL+"/v1/sms/quorum/unseal", "application/json", body)
- if err != nil {
- smslogger.WriteError("Error unsealing vault. Retrying... " + err.Error())
+ if smslogger.CheckError(err, "Unsealing Vault") != nil {
+ continue
}
}
}
diff --git a/sms-service/src/sms/Gopkg.lock b/sms-service/src/sms/Gopkg.lock
index c7684c7..2c09256 100644
--- a/sms-service/src/sms/Gopkg.lock
+++ b/sms-service/src/sms/Gopkg.lock
@@ -477,6 +477,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
- inputs-digest = "d19e17a023506ab731b0f26c6fcfebe581d4d5194af094aecea5e72daddd3ead"
+ inputs-digest = "8280cde72a3ab78ad00d13c192de5920d188f3052f45884563896cab659469f9"
solver-name = "gps-cdcl"
solver-version = 1
diff --git a/sms-service/src/sms/auth/auth.go b/sms-service/src/sms/auth/auth.go
index 7172505..038e31d 100644
--- a/sms-service/src/sms/auth/auth.go
+++ b/sms-service/src/sms/auth/auth.go
@@ -29,39 +29,27 @@ import (
smslogger "sms/log"
)
-var tlsConfig *tls.Config
-
-func checkError(err error, topic string) error {
- if err != nil {
- smslogger.WriteError(topic + ": " + err.Error())
- return err
- }
-
- return nil
-}
-
// GetTLSConfig initializes a tlsConfig using the CA's certificate
// This config is then used to enable the server for mutual TLS
func GetTLSConfig(caCertFile string) (*tls.Config, error) {
+
// Initialize tlsConfig once
- if tlsConfig == nil {
- caCert, err := ioutil.ReadFile(caCertFile)
+ caCert, err := ioutil.ReadFile(caCertFile)
- if err != nil {
- return nil, err
- }
+ if err != nil {
+ return nil, err
+ }
- caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
- tlsConfig = &tls.Config{
- // Change to RequireAndVerify once we have mandatory certs
- ClientAuth: tls.VerifyClientCertIfGiven,
- ClientCAs: caCertPool,
- MinVersion: tls.VersionTLS12,
- }
- tlsConfig.BuildNameToCertificate()
+ tlsConfig := &tls.Config{
+ // Change to RequireAndVerify once we have mandatory certs
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ ClientCAs: caCertPool,
+ MinVersion: tls.VersionTLS12,
}
+ tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}
@@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) {
// A base64 encoded form of the public part of the entity
// A base64 encoded form of the private key
func GeneratePGPKeyPair() (string, string, error) {
+
var entity *openpgp.Entity
config := &packet.Config{
DefaultHash: crypto.SHA256,
}
entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Entity") != nil {
return "", "", err
}
// Sign the identity in the entity
for _, id := range entity.Identities {
err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Entity") != nil {
return "", "", err
}
}
@@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) {
// Sign the subkey in the entity
for _, subkey := range entity.Subkeys {
err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Subkey") != nil {
return "", "", err
}
}
@@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) {
// EncryptPGPString takes data and a public key and encrypts using that
// public key
func EncryptPGPString(data string, pbKey string) (string, error) {
+
pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
- if checkError(err, "Decoding Base64 Public Key") != nil {
+ if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil {
return "", err
}
dataBytes := []byte(data)
pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
- if checkError(err, "Reading entity from PGP key") != nil {
+ if smslogger.CheckError(err, "Reading entity from PGP key") != nil {
return "", err
}
// encrypt string
buf := new(bytes.Buffer)
out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
- if checkError(err, "Creating Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Creating Encryption Pipe") != nil {
return "", err
}
_, err = out.Write(dataBytes)
- if checkError(err, "Writing to Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil {
return "", err
}
err = out.Close()
- if checkError(err, "Closing Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Closing Encryption Pipe") != nil {
return "", err
}
@@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) {
// DecryptPGPString decrypts a PGP encoded input string and returns
// a base64 representation of the decoded string
func DecryptPGPString(data string, prKey string) (string, error) {
+
// Convert private key to bytes from base64
prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+ if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil {
return "", err
}
dataBytes, err := base64.StdEncoding.DecodeString(data)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+ if smslogger.CheckError(err, "Decoding base64 data") != nil {
return "", err
}
prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
- if err != nil {
- smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+ if smslogger.CheckError(err, "Read Entity") != nil {
return "", err
}
prEntityList := &openpgp.EntityList{prEntity}
message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
- if err != nil {
- smslogger.WriteError("Error Decrypting message: " + err.Error())
+ if smslogger.CheckError(err, "Decrypting Message") != nil {
return "", err
}
@@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) {
func ReadFromFile(fileName string) (string, error) {
data, err := ioutil.ReadFile(fileName)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot read file: " + fileName)
+ if smslogger.CheckError(err, "Read from file") != nil {
return "", err
}
return string(data), nil
-
}
// WriteToFile writes a PGP key into a file.
@@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) {
func WriteToFile(data string, fileName string) error {
err := ioutil.WriteFile(fileName, []byte(data), 0600)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot write to file: " + fileName)
+ if smslogger.CheckError(err, "Write to file") != nil {
return err
}
return nil
-
}
diff --git a/sms-service/src/sms/backend/backend.go b/sms-service/src/sms/backend/backend.go
index c137636..d7662ef 100644
--- a/sms-service/src/sms/backend/backend.go
+++ b/sms-service/src/sms/backend/backend.go
@@ -60,8 +60,7 @@ func InitSecretBackend() (SecretBackend, error) {
}
err := backendImpl.Init()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitSecretBackend") != nil {
return nil, err
}
diff --git a/sms-service/src/sms/backend/vault.go b/sms-service/src/sms/backend/vault.go
index e26baff..7fee097 100644
--- a/sms-service/src/sms/backend/vault.go
+++ b/sms-service/src/sms/backend/vault.go
@@ -56,9 +56,8 @@ func (v *Vault) initVaultClient() error {
vaultCFG := vaultapi.DefaultConfig()
vaultCFG.Address = v.vaultAddress
client, err := vaultapi.NewClient(vaultCFG)
- if err != nil {
- smslogger.WriteError(err.Error())
- return errors.New("Unable to create new vault client")
+ if smslogger.CheckError(err, "Create new vault client") != nil {
+ return err
}
v.initRoleDone = false
@@ -69,7 +68,6 @@ func (v *Vault) initVaultClient() error {
v.internalDomainMounted = false
v.prkey = ""
return nil
-
}
// Init will initialize the vault connection
@@ -84,8 +82,7 @@ func (v *Vault) Init() error {
v.initializeVault()
err := v.initRole()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitRole First Attempt") != nil {
smslogger.WriteInfo("InitRole will try again later")
}
@@ -94,10 +91,10 @@ func (v *Vault) Init() error {
// GetStatus returns the current seal status of vault
func (v *Vault) GetStatus() (bool, error) {
+
sys := v.vaultClient.Sys()
sealStatus, err := sys.SealStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Getting Status") != nil {
return false, errors.New("Error getting status")
}
@@ -112,7 +109,7 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
defer v.Unlock()
if v.shards == nil {
- smslogger.WriteError("Invalid operation")
+ smslogger.WriteError("Invalid operation in RegisterQuorum")
return "", errors.New("Invalid operation")
}
// Pop the slice
@@ -133,10 +130,10 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
// Unseal is a passthrough API that allows any
// unseal or initialization processes for the backend
func (v *Vault) Unseal(shard string) error {
+
sys := v.vaultClient.Sys()
_, err := sys.Unseal(shard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Unseal Operation") != nil {
return errors.New("Unable to execute unseal operation with specified shard")
}
@@ -147,17 +144,16 @@ func (v *Vault) Unseal(shard string) error {
// The secret itself is referenced via its name which translates to
// a mount path in vault
func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Tocken Check") != nil {
return Secret{}, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return Secret{}, errors.New("Unable to read Secret at provided path")
}
@@ -173,17 +169,16 @@ func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
// ListSecret returns a list of secret names on a particular domain
// The values of the secret are not returned
func (v *Vault) ListSecret(dom string) ([]string, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return nil, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().List(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return nil, errors.New("Unable to read Secret at provided path")
}
@@ -209,6 +204,7 @@ func (v *Vault) ListSecret(dom string) ([]string, error) {
// Mounts the internal Domain if its not already mounted
func (v *Vault) mountInternalDomain(name string) error {
+
if v.internalDomainMounted {
return nil
}
@@ -224,14 +220,13 @@ func (v *Vault) mountInternalDomain(name string) error {
}
err := v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
+ if smslogger.CheckError(err, "Mount internal Domain") != nil {
if strings.Contains(err.Error(), "existing mount") {
// It is already mounted
v.internalDomainMounted = true
return nil
}
// Ran into some other error mounting it.
- smslogger.WriteError(err.Error())
return errors.New("Unable to mount internal Domain")
}
@@ -242,16 +237,15 @@ func (v *Vault) mountInternalDomain(name string) error {
// Stores the UUID created for secretdomain in vault
// under v.vaultMountPrefix / smsinternal domain
func (v *Vault) storeUUID(uuid string, name string) error {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check failed")
}
err = v.mountInternalDomain(v.internalDomain)
- if err != nil {
- smslogger.WriteError("Could not mount internal domain")
+ if smslogger.CheckError(err, "Mount Internal Domain") != nil {
return err
}
@@ -263,8 +257,7 @@ func (v *Vault) storeUUID(uuid string, name string) error {
}
err = v.CreateSecret(v.internalDomain, secret)
- if err != nil {
- smslogger.WriteError("Unable to write UUID to internal domain")
+ if smslogger.CheckError(err, "Write UUID to domain") != nil {
return err
}
@@ -273,10 +266,10 @@ func (v *Vault) storeUUID(uuid string, name string) error {
// CreateSecretDomain mounts the kv backend on a path with the given name
func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return SecretDomain{}, errors.New("Token Check failed")
}
@@ -291,14 +284,13 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
}
err = v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Domain") != nil {
return SecretDomain{}, errors.New("Unable to create Secret Domain")
}
uuid, _ := uuid.GenerateUUID()
err = v.storeUUID(uuid, name)
- if err != nil {
+ if smslogger.CheckError(err, "Store UUID") != nil {
// Mount was successful at this point.
// Rollback the mount operation since we could not
// store the UUID for the mount.
@@ -312,9 +304,9 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
// CreateSecret creates a secret mounted on a particular domain name
// The secret itself is mounted on a path specified by name
func (v *Vault) CreateSecret(dom string, sec Secret) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -323,8 +315,7 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// Vault return is empty on successful write
// TODO: Check if values is not empty
_, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Secret") != nil {
return errors.New("Unable to create Secret at provided path")
}
@@ -334,9 +325,9 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// DeleteSecretDomain deletes a secret domain which translates to
// an unmount operation on the given path in Vault
func (v *Vault) DeleteSecretDomain(name string) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check Failed")
}
@@ -344,8 +335,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
mountPath := v.vaultMountPrefix + "/" + name
err = v.vaultClient.Sys().Unmount(mountPath)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Domain") != nil {
return errors.New("Unable to delete domain specified")
}
@@ -356,8 +346,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
func (v *Vault) DeleteSecret(dom string, name string) error {
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -365,8 +354,7 @@ func (v *Vault) DeleteSecret(dom string, name string) error {
// Vault return is empty on successful delete
_, err = v.vaultClient.Logical().Delete(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Secret") != nil {
return errors.New("Unable to delete Secret at provided path")
}
@@ -406,15 +394,13 @@ func (v *Vault) initRole() error {
rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Creating Policy") != nil {
return errors.New("Unable to create policy for approle creation")
}
//Check if applrole is mounted
authMounts, err := v.vaultClient.Sys().ListAuth()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Mount Auth Backend") != nil {
return errors.New("Unable to get mounted auth backends")
}
@@ -440,8 +426,7 @@ func (v *Vault) initRole() error {
// Create a role-id
v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create RoleID") != nil {
return errors.New("Unable to create role ID for approle")
}
v.roleID = sec.Data["role_id"].(string)
@@ -449,8 +434,7 @@ func (v *Vault) initRole() error {
// Create a secret-id to go with it
sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
map[string]interface{}{})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create SecretID") != nil {
return errors.New("Unable to create secret ID for role")
}
@@ -462,8 +446,7 @@ func (v *Vault) initRole() error {
* using the unseal shards.
*/
err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
- if err != nil {
- smslogger.WriteWarn(err.Error())
+ if smslogger.CheckError(err, "Revoke Root Token") != nil {
smslogger.WriteWarn("Unable to Revoke Token")
} else {
// Revoked successfully and clear it
@@ -481,6 +464,7 @@ func (v *Vault) initRole() error {
// Function checkToken() gets called multiple times to create
// temporary tokens
func (v *Vault) checkToken() error {
+
v.Lock()
defer v.Unlock()
@@ -501,8 +485,7 @@ func (v *Vault) checkToken() error {
// Create a temporary token using our roleID and secretID
out, err := v.vaultClient.Logical().Write("auth/approle/login",
map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Temp Token") != nil {
return errors.New("Unable to create Temporary Token for Role")
}
@@ -516,11 +499,12 @@ func (v *Vault) checkToken() error {
// vaultInit() is used to initialize the vault in cases where it is not
// initialized. This happens once during intial bring up.
func (v *Vault) initializeVault() error {
+
// Check for vault init status and don't exit till it is initialized
for {
init, err := v.vaultClient.Sys().InitStatus()
- if err != nil {
- smslogger.WriteError("Unable to get initStatus, trying again in 10s: " + err.Error())
+ if smslogger.CheckError(err, "Get Vault Init Status") != nil {
+ smslogger.WriteInfo("Trying again in 10s...")
time.Sleep(time.Second * 10)
continue
}
@@ -545,7 +529,7 @@ func (v *Vault) initializeVault() error {
pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
- if err != nil {
+ if smslogger.CheckError(err, "Generating PGP Keys") != nil {
smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
} else {
initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
@@ -553,8 +537,7 @@ func (v *Vault) initializeVault() error {
}
resp, err := v.vaultClient.Sys().Init(initReq)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Initialize Vault") != nil {
return errors.New("FATAL: Unable to initialize Vault")
}
diff --git a/sms-service/src/sms/backend/vault_test.go b/sms-service/src/sms/backend/vault_test.go
index 484c395..4862665 100644
--- a/sms-service/src/sms/backend/vault_test.go
+++ b/sms-service/src/sms/backend/vault_test.go
@@ -17,9 +17,11 @@
package backend
import (
+ vaultapi "github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
vaultlogical "github.com/hashicorp/vault/logical"
+ vaultinmem "github.com/hashicorp/vault/physical/inmem"
vaulttesting "github.com/hashicorp/vault/vault"
"reflect"
smslog "sms/log"
@@ -229,3 +231,39 @@ func TestDeleteSecret(t *testing.T) {
t.Fatal("DeleteSecret: Error Creating secret")
}
}
+
+func TestInitializeVault(t *testing.T) {
+
+ inm, err := vaultinmem.NewInmem(nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ core, err := vaulttesting.NewCore(&vaulttesting.CoreConfig{
+ DisableMlock: true,
+ DisableCache: true,
+ Physical: inm,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ln, addr := vaulthttp.TestServer(t, core)
+ defer ln.Close()
+
+ client, err := vaultapi.NewClient(&vaultapi.Config{
+ Address: addr,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ v := &Vault{}
+ v.initVaultClient()
+ v.vaultClient = client
+
+ err = v.initializeVault()
+ if err != nil {
+ t.Fatal("InitializeVault: Error initializing Vault")
+ }
+}
diff --git a/sms-service/src/sms/coverage.md b/sms-service/src/sms/coverage.md
deleted file mode 100644
index 6168342..0000000
--- a/sms-service/src/sms/coverage.md
+++ /dev/null
@@ -1,41 +0,0 @@
-## Code Coverage Reports for Golang Applications ##
-
-This document covers how to generate HTML Code Coverage Reports for
-Golang Applications.
-
-#### Generate a test executable which calls your main()
-
-```sh
-$ go test -c -covermode=count -coverpkg ./...
-```
-
-#### Run the generated application to produce a new coverage report
-
-```sh
-$ ./sms.test -test.run "^TestMain$" -test.coverprofile=coverage.cov
-```
-
-#### Run your unit tests to produce their coverage report
-
-```sh
-$ go test -test.covermode=count -test.coverprofile=unit.out ./...
-```
-
-#### Merge the two coverage Reports
-
-```sh
-$ go get github.com/wadey/gocovmerge
-$ gocovmerge unit.out coverage.cov > all.out
-```
-
-#### Generate HTML Report
-
-```sh
-$ go tool cover -html all.out -o coverage.html
-```
-
-#### Generate Function Report
-
-```sh
-$ go tool cover -func all.out
-``` \ No newline at end of file
diff --git a/sms-service/src/sms/handler/handler.go b/sms-service/src/sms/handler/handler.go
index dbf3f93..7ce9e01 100644
--- a/sms-service/src/sms/handler/handler.go
+++ b/sms-service/src/sms/handler/handler.go
@@ -37,15 +37,13 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
var d smsbackend.SecretDomain
err := json.NewDecoder(r.Body).Decode(&d)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
dom, err := h.secretBackend.CreateSecretDomain(d.Name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -53,8 +51,7 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
err = json.NewEncoder(w).Encode(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -66,8 +63,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
domName := vars["domName"]
err := h.secretBackend.DeleteSecretDomain(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -84,15 +80,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) {
// Get secrets to be stored from body
var b smsbackend.Secret
err := json.NewDecoder(r.Body).Decode(&b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = h.secretBackend.CreateSecret(domName, b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -107,16 +101,14 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
sec, err := h.secretBackend.GetSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(sec)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -128,8 +120,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
domName := vars["domName"]
secList, err := h.secretBackend.ListSecret(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -143,8 +134,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(retStruct)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -157,8 +147,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
err := h.secretBackend.DeleteSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -169,8 +158,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
// statusHandler returns information related to SMS and SMS backend services
func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
s, err := h.secretBackend.GetStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -183,8 +171,7 @@ func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(status)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -207,15 +194,13 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
err = h.secretBackend.Unseal(inp.UnsealShard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -235,15 +220,13 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -257,8 +240,7 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(shStruct)
- if err != nil {
- smslogger.WriteError("Unable to encode response: " + err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
diff --git a/sms-service/src/sms/handler/handler_test.go b/sms-service/src/sms/handler/handler_test.go
index 52637f3..c1e55ed 100644
--- a/sms-service/src/sms/handler/handler_test.go
+++ b/sms-service/src/sms/handler/handler_test.go
@@ -48,7 +48,7 @@ func (b *TestBackend) Unseal(shard string) error {
}
func (b *TestBackend) RegisterQuorum(pgpkey string) (string, error) {
- return "", nil
+ return "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", nil
}
func (b *TestBackend) GetSecret(dom string, sec string) (smsbackend.Secret, error) {
@@ -127,8 +127,49 @@ func TestStatusHandler(t *testing.T) {
}
}
+func TestRegisterHandler(t *testing.T) {
+ body := `{
+ "pgpkey":"asdasdasdasdgkjgljoiwera",
+ "quorumid":"123e4567-e89b-12d3-a456-426655440000"
+ }`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/register", reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rr := httptest.NewRecorder()
+ hr := http.HandlerFunc(h.registerHandler)
+
+ hr.ServeHTTP(rr, req)
+
+ ret := rr.Code
+ if ret != http.StatusOK {
+ t.Errorf("registerHandler returned wrong status code: %v vs %v",
+ ret, http.StatusOK)
+ }
+
+ expected := struct {
+ Shard string `json:"shard"`
+ }{
+ "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=",
+ }
+ got := struct {
+ Shard string `json:"shard"`
+ }{}
+
+ json.NewDecoder(rr.Body).Decode(&got)
+
+ if reflect.DeepEqual(expected, got) == false {
+ t.Errorf("statusHandler returned unexpected body: got %v vs %v",
+ rr.Body.String(), expected)
+ }
+}
+
func TestUnsealHandler(t *testing.T) {
- req, err := http.NewRequest("GET", "/v1/sms/quorum/unseal", nil)
+ body := `{"unsealshard":"N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU="}`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/unseal", reader)
if err != nil {
t.Fatal(err)
}
diff --git a/sms-service/src/sms/log/logger.go b/sms-service/src/sms/log/logger.go
index 25da593..660f1ce 100644
--- a/sms-service/src/sms/log/logger.go
+++ b/sms-service/src/sms/log/logger.go
@@ -17,57 +17,82 @@
package log
import (
+ "fmt"
"log"
"os"
)
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
// Init will be called by sms.go before any other packages use it
func Init(filePath string) {
+
+ stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
if filePath == "" {
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+ // We will just to std streams
return
}
f, err := os.Create(filePath)
if err != nil {
- log.Println("Unable to create a log file")
- log.Println(err)
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
- } else {
- errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+ stdErr.Println("Unable to create log file: " + err.Error())
+ return
}
+
+ errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
}
// WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
func WriteError(msg string) {
- if errLogger != nil {
- errLogger.Println(msg)
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
}
}
// WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
func WriteWarn(msg string) {
- if warnLogger != nil {
- warnLogger.Println(msg)
+ if warnL != nil {
+ warnL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdWarn != nil {
+ stdWarn.Output(2, fmt.Sprintln(msg))
}
}
// WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
+// defined during its creation with INFO prefix
func WriteInfo(msg string) {
- if infoLogger != nil {
- infoLogger.Println(msg)
+ if infoL != nil {
+ infoL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdInfo != nil {
+ stdInfo.Output(2, fmt.Sprintln(msg))
+ }
+}
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error {
+ if err != nil {
+ msg := topic + ": " + err.Error()
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
+ }
+ return err
}
+ return nil
}