summaryrefslogtreecommitdiffstats
path: root/sms-service/src
diff options
context:
space:
mode:
authorKiran Kamineni <kiran.k.kamineni@intel.com>2018-04-19 21:27:01 -0700
committerKiran Kamineni <kiran.k.kamineni@intel.com>2018-04-20 14:48:26 -0700
commit7597c1552d636712391d7269d0373747384ced0d (patch)
treeead1edb21947111935417432231d0bab6a1f058c /sms-service/src
parent333da2a55ef9535a32d90e249ab7f3842944db6a (diff)
Refactor logger and use it everywhere
Refactored the logger to print the right line number. This is done by using the runtime.caller function within the logger.output function Issue-ID: AAF-257 Change-Id: Ie26de43ca74c71f382d3b5f93ebd4eaf6d51e2b4 Signed-off-by: Kiran Kamineni <kiran.k.kamineni@intel.com>
Diffstat (limited to 'sms-service/src')
-rw-r--r--sms-service/src/quorumclient/quorumclient.go22
-rw-r--r--sms-service/src/sms/Gopkg.lock2
-rw-r--r--sms-service/src/sms/auth/auth.go82
-rw-r--r--sms-service/src/sms/backend/backend.go3
-rw-r--r--sms-service/src/sms/backend/vault.go105
-rw-r--r--sms-service/src/sms/backend/vault_test.go38
-rw-r--r--sms-service/src/sms/coverage.md41
-rw-r--r--sms-service/src/sms/handler/handler.go54
-rw-r--r--sms-service/src/sms/handler/handler_test.go45
-rw-r--r--sms-service/src/sms/log/logger.go73
10 files changed, 234 insertions, 231 deletions
diff --git a/sms-service/src/quorumclient/quorumclient.go b/sms-service/src/quorumclient/quorumclient.go
index 05fc967..dfa1a26 100644
--- a/sms-service/src/quorumclient/quorumclient.go
+++ b/sms-service/src/quorumclient/quorumclient.go
@@ -37,13 +37,13 @@ func loadPGPKeys(prKeyPath string, pbKeyPath string) (string, string, error) {
var pbkey, prkey string
generated := false
prkey, err := smsauth.ReadFromFile(prKeyPath)
- if err != nil {
- smslogger.WriteWarn("No Private Key found. Generating...")
+ if smslogger.CheckError(err, "LoadPGP Private Key") != nil {
+ smslogger.WriteInfo("No Private Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
} else {
pbkey, err = smsauth.ReadFromFile(pbKeyPath)
- if err != nil {
+ if smslogger.CheckError(err, "LoadPGP Public Key") != nil {
smslogger.WriteWarn("No Public Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
@@ -70,7 +70,7 @@ func main() {
prKeyPath := filepath.Join("auth", podName, "prkey")
shardPath := filepath.Join("auth", podName, "shard")
- smslogger.Init("")
+ smslogger.Init("quorum.log")
smslogger.WriteInfo("Starting Log for Quorum Client")
/*
@@ -80,7 +80,7 @@ func main() {
In Kubernetes, pod restarts will also change the hostname
*/
myID, err := smsauth.ReadFromFile(idFilePath)
- if err != nil {
+ if smslogger.CheckError(err, "Read ID") != nil {
smslogger.WriteWarn("Unable to find an ID for this client. Generating...")
myID, _ = uuid.GenerateUUID()
smsauth.WriteToFile(myID, idFilePath)
@@ -93,7 +93,7 @@ func main() {
*/
registrationDone := true
myShard, err := smsauth.ReadFromFile(shardPath)
- if err != nil {
+ if smslogger.CheckError(err, "Read Shard") != nil {
smslogger.WriteWarn("Unable to find a shard file. Registering with SMS...")
registrationDone = false
}
@@ -160,8 +160,7 @@ func main() {
//URL and Port is configured in config file
response, err := client.Get(cfg.BackEndURL + "/v1/sms/quorum/status")
- if err != nil {
- smslogger.WriteError("Unable to connect to SMS. Retrying...")
+ if smslogger.CheckError(err, "Connect to SMS") != nil {
continue
}
@@ -178,8 +177,7 @@ func main() {
if !registrationDone {
body := strings.NewReader(`{"pgpkey":"` + pbkey + `","quorumid":"` + myID + `"}`)
res, err := client.Post(cfg.BackEndURL+"/v1/sms/quorum/register", "application/json", body)
- if err != nil {
- smslogger.WriteError("Ran into error during registration. Retrying...")
+ if smslogger.CheckError(err, "Register with SMS") != nil {
continue
}
registrationDone = true
@@ -195,8 +193,8 @@ func main() {
body := strings.NewReader(`{"unsealshard":"` + decShard + `"}`)
//URL and PORT is configured via config file
response, err = client.Post(cfg.BackEndURL+"/v1/sms/quorum/unseal", "application/json", body)
- if err != nil {
- smslogger.WriteError("Error unsealing vault. Retrying... " + err.Error())
+ if smslogger.CheckError(err, "Unsealing Vault") != nil {
+ continue
}
}
}
diff --git a/sms-service/src/sms/Gopkg.lock b/sms-service/src/sms/Gopkg.lock
index c7684c7..2c09256 100644
--- a/sms-service/src/sms/Gopkg.lock
+++ b/sms-service/src/sms/Gopkg.lock
@@ -477,6 +477,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
- inputs-digest = "d19e17a023506ab731b0f26c6fcfebe581d4d5194af094aecea5e72daddd3ead"
+ inputs-digest = "8280cde72a3ab78ad00d13c192de5920d188f3052f45884563896cab659469f9"
solver-name = "gps-cdcl"
solver-version = 1
diff --git a/sms-service/src/sms/auth/auth.go b/sms-service/src/sms/auth/auth.go
index 7172505..038e31d 100644
--- a/sms-service/src/sms/auth/auth.go
+++ b/sms-service/src/sms/auth/auth.go
@@ -29,39 +29,27 @@ import (
smslogger "sms/log"
)
-var tlsConfig *tls.Config
-
-func checkError(err error, topic string) error {
- if err != nil {
- smslogger.WriteError(topic + ": " + err.Error())
- return err
- }
-
- return nil
-}
-
// GetTLSConfig initializes a tlsConfig using the CA's certificate
// This config is then used to enable the server for mutual TLS
func GetTLSConfig(caCertFile string) (*tls.Config, error) {
+
// Initialize tlsConfig once
- if tlsConfig == nil {
- caCert, err := ioutil.ReadFile(caCertFile)
+ caCert, err := ioutil.ReadFile(caCertFile)
- if err != nil {
- return nil, err
- }
+ if err != nil {
+ return nil, err
+ }
- caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
- tlsConfig = &tls.Config{
- // Change to RequireAndVerify once we have mandatory certs
- ClientAuth: tls.VerifyClientCertIfGiven,
- ClientCAs: caCertPool,
- MinVersion: tls.VersionTLS12,
- }
- tlsConfig.BuildNameToCertificate()
+ tlsConfig := &tls.Config{
+ // Change to RequireAndVerify once we have mandatory certs
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ ClientCAs: caCertPool,
+ MinVersion: tls.VersionTLS12,
}
+ tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}
@@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) {
// A base64 encoded form of the public part of the entity
// A base64 encoded form of the private key
func GeneratePGPKeyPair() (string, string, error) {
+
var entity *openpgp.Entity
config := &packet.Config{
DefaultHash: crypto.SHA256,
}
entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Entity") != nil {
return "", "", err
}
// Sign the identity in the entity
for _, id := range entity.Identities {
err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Entity") != nil {
return "", "", err
}
}
@@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) {
// Sign the subkey in the entity
for _, subkey := range entity.Subkeys {
err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Subkey") != nil {
return "", "", err
}
}
@@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) {
// EncryptPGPString takes data and a public key and encrypts using that
// public key
func EncryptPGPString(data string, pbKey string) (string, error) {
+
pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
- if checkError(err, "Decoding Base64 Public Key") != nil {
+ if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil {
return "", err
}
dataBytes := []byte(data)
pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
- if checkError(err, "Reading entity from PGP key") != nil {
+ if smslogger.CheckError(err, "Reading entity from PGP key") != nil {
return "", err
}
// encrypt string
buf := new(bytes.Buffer)
out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
- if checkError(err, "Creating Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Creating Encryption Pipe") != nil {
return "", err
}
_, err = out.Write(dataBytes)
- if checkError(err, "Writing to Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil {
return "", err
}
err = out.Close()
- if checkError(err, "Closing Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Closing Encryption Pipe") != nil {
return "", err
}
@@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) {
// DecryptPGPString decrypts a PGP encoded input string and returns
// a base64 representation of the decoded string
func DecryptPGPString(data string, prKey string) (string, error) {
+
// Convert private key to bytes from base64
prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+ if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil {
return "", err
}
dataBytes, err := base64.StdEncoding.DecodeString(data)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+ if smslogger.CheckError(err, "Decoding base64 data") != nil {
return "", err
}
prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
- if err != nil {
- smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+ if smslogger.CheckError(err, "Read Entity") != nil {
return "", err
}
prEntityList := &openpgp.EntityList{prEntity}
message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
- if err != nil {
- smslogger.WriteError("Error Decrypting message: " + err.Error())
+ if smslogger.CheckError(err, "Decrypting Message") != nil {
return "", err
}
@@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) {
func ReadFromFile(fileName string) (string, error) {
data, err := ioutil.ReadFile(fileName)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot read file: " + fileName)
+ if smslogger.CheckError(err, "Read from file") != nil {
return "", err
}
return string(data), nil
-
}
// WriteToFile writes a PGP key into a file.
@@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) {
func WriteToFile(data string, fileName string) error {
err := ioutil.WriteFile(fileName, []byte(data), 0600)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot write to file: " + fileName)
+ if smslogger.CheckError(err, "Write to file") != nil {
return err
}
return nil
-
}
diff --git a/sms-service/src/sms/backend/backend.go b/sms-service/src/sms/backend/backend.go
index c137636..d7662ef 100644
--- a/sms-service/src/sms/backend/backend.go
+++ b/sms-service/src/sms/backend/backend.go
@@ -60,8 +60,7 @@ func InitSecretBackend() (SecretBackend, error) {
}
err := backendImpl.Init()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitSecretBackend") != nil {
return nil, err
}
diff --git a/sms-service/src/sms/backend/vault.go b/sms-service/src/sms/backend/vault.go
index e26baff..7fee097 100644
--- a/sms-service/src/sms/backend/vault.go
+++ b/sms-service/src/sms/backend/vault.go
@@ -56,9 +56,8 @@ func (v *Vault) initVaultClient() error {
vaultCFG := vaultapi.DefaultConfig()
vaultCFG.Address = v.vaultAddress
client, err := vaultapi.NewClient(vaultCFG)
- if err != nil {
- smslogger.WriteError(err.Error())
- return errors.New("Unable to create new vault client")
+ if smslogger.CheckError(err, "Create new vault client") != nil {
+ return err
}
v.initRoleDone = false
@@ -69,7 +68,6 @@ func (v *Vault) initVaultClient() error {
v.internalDomainMounted = false
v.prkey = ""
return nil
-
}
// Init will initialize the vault connection
@@ -84,8 +82,7 @@ func (v *Vault) Init() error {
v.initializeVault()
err := v.initRole()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitRole First Attempt") != nil {
smslogger.WriteInfo("InitRole will try again later")
}
@@ -94,10 +91,10 @@ func (v *Vault) Init() error {
// GetStatus returns the current seal status of vault
func (v *Vault) GetStatus() (bool, error) {
+
sys := v.vaultClient.Sys()
sealStatus, err := sys.SealStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Getting Status") != nil {
return false, errors.New("Error getting status")
}
@@ -112,7 +109,7 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
defer v.Unlock()
if v.shards == nil {
- smslogger.WriteError("Invalid operation")
+ smslogger.WriteError("Invalid operation in RegisterQuorum")
return "", errors.New("Invalid operation")
}
// Pop the slice
@@ -133,10 +130,10 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
// Unseal is a passthrough API that allows any
// unseal or initialization processes for the backend
func (v *Vault) Unseal(shard string) error {
+
sys := v.vaultClient.Sys()
_, err := sys.Unseal(shard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Unseal Operation") != nil {
return errors.New("Unable to execute unseal operation with specified shard")
}
@@ -147,17 +144,16 @@ func (v *Vault) Unseal(shard string) error {
// The secret itself is referenced via its name which translates to
// a mount path in vault
func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Tocken Check") != nil {
return Secret{}, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return Secret{}, errors.New("Unable to read Secret at provided path")
}
@@ -173,17 +169,16 @@ func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
// ListSecret returns a list of secret names on a particular domain
// The values of the secret are not returned
func (v *Vault) ListSecret(dom string) ([]string, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return nil, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().List(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return nil, errors.New("Unable to read Secret at provided path")
}
@@ -209,6 +204,7 @@ func (v *Vault) ListSecret(dom string) ([]string, error) {
// Mounts the internal Domain if its not already mounted
func (v *Vault) mountInternalDomain(name string) error {
+
if v.internalDomainMounted {
return nil
}
@@ -224,14 +220,13 @@ func (v *Vault) mountInternalDomain(name string) error {
}
err := v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
+ if smslogger.CheckError(err, "Mount internal Domain") != nil {
if strings.Contains(err.Error(), "existing mount") {
// It is already mounted
v.internalDomainMounted = true
return nil
}
// Ran into some other error mounting it.
- smslogger.WriteError(err.Error())
return errors.New("Unable to mount internal Domain")
}
@@ -242,16 +237,15 @@ func (v *Vault) mountInternalDomain(name string) error {
// Stores the UUID created for secretdomain in vault
// under v.vaultMountPrefix / smsinternal domain
func (v *Vault) storeUUID(uuid string, name string) error {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check failed")
}
err = v.mountInternalDomain(v.internalDomain)
- if err != nil {
- smslogger.WriteError("Could not mount internal domain")
+ if smslogger.CheckError(err, "Mount Internal Domain") != nil {
return err
}
@@ -263,8 +257,7 @@ func (v *Vault) storeUUID(uuid string, name string) error {
}
err = v.CreateSecret(v.internalDomain, secret)
- if err != nil {
- smslogger.WriteError("Unable to write UUID to internal domain")
+ if smslogger.CheckError(err, "Write UUID to domain") != nil {
return err
}
@@ -273,10 +266,10 @@ func (v *Vault) storeUUID(uuid string, name string) error {
// CreateSecretDomain mounts the kv backend on a path with the given name
func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return SecretDomain{}, errors.New("Token Check failed")
}
@@ -291,14 +284,13 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
}
err = v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Domain") != nil {
return SecretDomain{}, errors.New("Unable to create Secret Domain")
}
uuid, _ := uuid.GenerateUUID()
err = v.storeUUID(uuid, name)
- if err != nil {
+ if smslogger.CheckError(err, "Store UUID") != nil {
// Mount was successful at this point.
// Rollback the mount operation since we could not
// store the UUID for the mount.
@@ -312,9 +304,9 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
// CreateSecret creates a secret mounted on a particular domain name
// The secret itself is mounted on a path specified by name
func (v *Vault) CreateSecret(dom string, sec Secret) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -323,8 +315,7 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// Vault return is empty on successful write
// TODO: Check if values is not empty
_, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Secret") != nil {
return errors.New("Unable to create Secret at provided path")
}
@@ -334,9 +325,9 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// DeleteSecretDomain deletes a secret domain which translates to
// an unmount operation on the given path in Vault
func (v *Vault) DeleteSecretDomain(name string) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check Failed")
}
@@ -344,8 +335,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
mountPath := v.vaultMountPrefix + "/" + name
err = v.vaultClient.Sys().Unmount(mountPath)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Domain") != nil {
return errors.New("Unable to delete domain specified")
}
@@ -356,8 +346,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
func (v *Vault) DeleteSecret(dom string, name string) error {
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -365,8 +354,7 @@ func (v *Vault) DeleteSecret(dom string, name string) error {
// Vault return is empty on successful delete
_, err = v.vaultClient.Logical().Delete(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Secret") != nil {
return errors.New("Unable to delete Secret at provided path")
}
@@ -406,15 +394,13 @@ func (v *Vault) initRole() error {
rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Creating Policy") != nil {
return errors.New("Unable to create policy for approle creation")
}
//Check if applrole is mounted
authMounts, err := v.vaultClient.Sys().ListAuth()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Mount Auth Backend") != nil {
return errors.New("Unable to get mounted auth backends")
}
@@ -440,8 +426,7 @@ func (v *Vault) initRole() error {
// Create a role-id
v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create RoleID") != nil {
return errors.New("Unable to create role ID for approle")
}
v.roleID = sec.Data["role_id"].(string)
@@ -449,8 +434,7 @@ func (v *Vault) initRole() error {
// Create a secret-id to go with it
sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
map[string]interface{}{})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create SecretID") != nil {
return errors.New("Unable to create secret ID for role")
}
@@ -462,8 +446,7 @@ func (v *Vault) initRole() error {
* using the unseal shards.
*/
err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
- if err != nil {
- smslogger.WriteWarn(err.Error())
+ if smslogger.CheckError(err, "Revoke Root Token") != nil {
smslogger.WriteWarn("Unable to Revoke Token")
} else {
// Revoked successfully and clear it
@@ -481,6 +464,7 @@ func (v *Vault) initRole() error {
// Function checkToken() gets called multiple times to create
// temporary tokens
func (v *Vault) checkToken() error {
+
v.Lock()
defer v.Unlock()
@@ -501,8 +485,7 @@ func (v *Vault) checkToken() error {
// Create a temporary token using our roleID and secretID
out, err := v.vaultClient.Logical().Write("auth/approle/login",
map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Temp Token") != nil {
return errors.New("Unable to create Temporary Token for Role")
}
@@ -516,11 +499,12 @@ func (v *Vault) checkToken() error {
// vaultInit() is used to initialize the vault in cases where it is not
// initialized. This happens once during intial bring up.
func (v *Vault) initializeVault() error {
+
// Check for vault init status and don't exit till it is initialized
for {
init, err := v.vaultClient.Sys().InitStatus()
- if err != nil {
- smslogger.WriteError("Unable to get initStatus, trying again in 10s: " + err.Error())
+ if smslogger.CheckError(err, "Get Vault Init Status") != nil {
+ smslogger.WriteInfo("Trying again in 10s...")
time.Sleep(time.Second * 10)
continue
}
@@ -545,7 +529,7 @@ func (v *Vault) initializeVault() error {
pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
- if err != nil {
+ if smslogger.CheckError(err, "Generating PGP Keys") != nil {
smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
} else {
initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
@@ -553,8 +537,7 @@ func (v *Vault) initializeVault() error {
}
resp, err := v.vaultClient.Sys().Init(initReq)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Initialize Vault") != nil {
return errors.New("FATAL: Unable to initialize Vault")
}
diff --git a/sms-service/src/sms/backend/vault_test.go b/sms-service/src/sms/backend/vault_test.go
index 484c395..4862665 100644
--- a/sms-service/src/sms/backend/vault_test.go
+++ b/sms-service/src/sms/backend/vault_test.go
@@ -17,9 +17,11 @@
package backend
import (
+ vaultapi "github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
vaultlogical "github.com/hashicorp/vault/logical"
+ vaultinmem "github.com/hashicorp/vault/physical/inmem"
vaulttesting "github.com/hashicorp/vault/vault"
"reflect"
smslog "sms/log"
@@ -229,3 +231,39 @@ func TestDeleteSecret(t *testing.T) {
t.Fatal("DeleteSecret: Error Creating secret")
}
}
+
+func TestInitializeVault(t *testing.T) {
+
+ inm, err := vaultinmem.NewInmem(nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ core, err := vaulttesting.NewCore(&vaulttesting.CoreConfig{
+ DisableMlock: true,
+ DisableCache: true,
+ Physical: inm,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ln, addr := vaulthttp.TestServer(t, core)
+ defer ln.Close()
+
+ client, err := vaultapi.NewClient(&vaultapi.Config{
+ Address: addr,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ v := &Vault{}
+ v.initVaultClient()
+ v.vaultClient = client
+
+ err = v.initializeVault()
+ if err != nil {
+ t.Fatal("InitializeVault: Error initializing Vault")
+ }
+}
diff --git a/sms-service/src/sms/coverage.md b/sms-service/src/sms/coverage.md
deleted file mode 100644
index 6168342..0000000
--- a/sms-service/src/sms/coverage.md
+++ /dev/null
@@ -1,41 +0,0 @@
-## Code Coverage Reports for Golang Applications ##
-
-This document covers how to generate HTML Code Coverage Reports for
-Golang Applications.
-
-#### Generate a test executable which calls your main()
-
-```sh
-$ go test -c -covermode=count -coverpkg ./...
-```
-
-#### Run the generated application to produce a new coverage report
-
-```sh
-$ ./sms.test -test.run "^TestMain$" -test.coverprofile=coverage.cov
-```
-
-#### Run your unit tests to produce their coverage report
-
-```sh
-$ go test -test.covermode=count -test.coverprofile=unit.out ./...
-```
-
-#### Merge the two coverage Reports
-
-```sh
-$ go get github.com/wadey/gocovmerge
-$ gocovmerge unit.out coverage.cov > all.out
-```
-
-#### Generate HTML Report
-
-```sh
-$ go tool cover -html all.out -o coverage.html
-```
-
-#### Generate Function Report
-
-```sh
-$ go tool cover -func all.out
-``` \ No newline at end of file
diff --git a/sms-service/src/sms/handler/handler.go b/sms-service/src/sms/handler/handler.go
index dbf3f93..7ce9e01 100644
--- a/sms-service/src/sms/handler/handler.go
+++ b/sms-service/src/sms/handler/handler.go
@@ -37,15 +37,13 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
var d smsbackend.SecretDomain
err := json.NewDecoder(r.Body).Decode(&d)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
dom, err := h.secretBackend.CreateSecretDomain(d.Name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -53,8 +51,7 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
err = json.NewEncoder(w).Encode(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -66,8 +63,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
domName := vars["domName"]
err := h.secretBackend.DeleteSecretDomain(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -84,15 +80,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) {
// Get secrets to be stored from body
var b smsbackend.Secret
err := json.NewDecoder(r.Body).Decode(&b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = h.secretBackend.CreateSecret(domName, b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -107,16 +101,14 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
sec, err := h.secretBackend.GetSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(sec)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -128,8 +120,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
domName := vars["domName"]
secList, err := h.secretBackend.ListSecret(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -143,8 +134,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(retStruct)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -157,8 +147,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
err := h.secretBackend.DeleteSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -169,8 +158,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
// statusHandler returns information related to SMS and SMS backend services
func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
s, err := h.secretBackend.GetStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -183,8 +171,7 @@ func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(status)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -207,15 +194,13 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
err = h.secretBackend.Unseal(inp.UnsealShard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -235,15 +220,13 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -257,8 +240,7 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(shStruct)
- if err != nil {
- smslogger.WriteError("Unable to encode response: " + err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
diff --git a/sms-service/src/sms/handler/handler_test.go b/sms-service/src/sms/handler/handler_test.go
index 52637f3..c1e55ed 100644
--- a/sms-service/src/sms/handler/handler_test.go
+++ b/sms-service/src/sms/handler/handler_test.go
@@ -48,7 +48,7 @@ func (b *TestBackend) Unseal(shard string) error {
}
func (b *TestBackend) RegisterQuorum(pgpkey string) (string, error) {
- return "", nil
+ return "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", nil
}
func (b *TestBackend) GetSecret(dom string, sec string) (smsbackend.Secret, error) {
@@ -127,8 +127,49 @@ func TestStatusHandler(t *testing.T) {
}
}
+func TestRegisterHandler(t *testing.T) {
+ body := `{
+ "pgpkey":"asdasdasdasdgkjgljoiwera",
+ "quorumid":"123e4567-e89b-12d3-a456-426655440000"
+ }`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/register", reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rr := httptest.NewRecorder()
+ hr := http.HandlerFunc(h.registerHandler)
+
+ hr.ServeHTTP(rr, req)
+
+ ret := rr.Code
+ if ret != http.StatusOK {
+ t.Errorf("registerHandler returned wrong status code: %v vs %v",
+ ret, http.StatusOK)
+ }
+
+ expected := struct {
+ Shard string `json:"shard"`
+ }{
+ "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=",
+ }
+ got := struct {
+ Shard string `json:"shard"`
+ }{}
+
+ json.NewDecoder(rr.Body).Decode(&got)
+
+ if reflect.DeepEqual(expected, got) == false {
+ t.Errorf("statusHandler returned unexpected body: got %v vs %v",
+ rr.Body.String(), expected)
+ }
+}
+
func TestUnsealHandler(t *testing.T) {
- req, err := http.NewRequest("GET", "/v1/sms/quorum/unseal", nil)
+ body := `{"unsealshard":"N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU="}`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/unseal", reader)
if err != nil {
t.Fatal(err)
}
diff --git a/sms-service/src/sms/log/logger.go b/sms-service/src/sms/log/logger.go
index 25da593..660f1ce 100644
--- a/sms-service/src/sms/log/logger.go
+++ b/sms-service/src/sms/log/logger.go
@@ -17,57 +17,82 @@
package log
import (
+ "fmt"
"log"
"os"
)
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
// Init will be called by sms.go before any other packages use it
func Init(filePath string) {
+
+ stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
if filePath == "" {
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+ // We will just to std streams
return
}
f, err := os.Create(filePath)
if err != nil {
- log.Println("Unable to create a log file")
- log.Println(err)
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
- } else {
- errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+ stdErr.Println("Unable to create log file: " + err.Error())
+ return
}
+
+ errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
}
// WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
func WriteError(msg string) {
- if errLogger != nil {
- errLogger.Println(msg)
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
}
}
// WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
func WriteWarn(msg string) {
- if warnLogger != nil {
- warnLogger.Println(msg)
+ if warnL != nil {
+ warnL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdWarn != nil {
+ stdWarn.Output(2, fmt.Sprintln(msg))
}
}
// WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
+// defined during its creation with INFO prefix
func WriteInfo(msg string) {
- if infoLogger != nil {
- infoLogger.Println(msg)
+ if infoL != nil {
+ infoL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdInfo != nil {
+ stdInfo.Output(2, fmt.Sprintln(msg))
+ }
+}
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error {
+ if err != nil {
+ msg := topic + ": " + err.Error()
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
+ }
+ return err
}
+ return nil
}