summaryrefslogtreecommitdiffstats
path: root/sms-service
diff options
context:
space:
mode:
Diffstat (limited to 'sms-service')
-rw-r--r--sms-service/doc/coverage.html861
-rw-r--r--sms-service/doc/coverage.md (renamed from sms-service/src/sms/coverage.md)0
-rw-r--r--sms-service/src/quorumclient/quorumclient.go22
-rw-r--r--sms-service/src/sms/Gopkg.lock2
-rw-r--r--sms-service/src/sms/auth/auth.go82
-rw-r--r--sms-service/src/sms/backend/backend.go3
-rw-r--r--sms-service/src/sms/backend/vault.go105
-rw-r--r--sms-service/src/sms/backend/vault_test.go38
-rw-r--r--sms-service/src/sms/handler/handler.go54
-rw-r--r--sms-service/src/sms/handler/handler_test.go45
-rw-r--r--sms-service/src/sms/log/logger.go73
11 files changed, 809 insertions, 476 deletions
diff --git a/sms-service/doc/coverage.html b/sms-service/doc/coverage.html
index d03ddde..39ee191 100644
--- a/sms-service/doc/coverage.html
+++ b/sms-service/doc/coverage.html
@@ -54,19 +54,19 @@
<div id="nav">
<select id="files">
- <option value="file0">sms/auth/auth.go (17.6%)</option>
+ <option value="file0">sms/auth/auth.go (76.1%)</option>
- <option value="file1">sms/backend/backend.go (66.7%)</option>
+ <option value="file1">sms/backend/backend.go (80.0%)</option>
- <option value="file2">sms/backend/vault.go (60.5%)</option>
+ <option value="file2">sms/backend/vault.go (72.5%)</option>
- <option value="file3">sms/config/config.go (90.9%)</option>
+ <option value="file3">sms/config/config.go (78.6%)</option>
- <option value="file4">sms/handler/handler.go (55.1%)</option>
+ <option value="file4">sms/handler/handler.go (63.0%)</option>
- <option value="file5">sms/log/logger.go (31.2%)</option>
+ <option value="file5">sms/log/logger.go (65.6%)</option>
- <option value="file6">sms/sms.go (82.6%)</option>
+ <option value="file6">sms/sms.go (77.8%)</option>
</select>
</div>
@@ -109,6 +109,7 @@ package auth
import (
"bytes"
+ "crypto"
"crypto/tls"
"crypto/x509"
"encoding/base64"
@@ -119,63 +120,63 @@ import (
smslogger "sms/log"
)
-var tlsConfig *tls.Config
-
// GetTLSConfig initializes a tlsConfig using the CA's certificate
// This config is then used to enable the server for mutual TLS
func GetTLSConfig(caCertFile string) (*tls.Config, error) <span class="cov10" title="3">{
+
// Initialize tlsConfig once
- if tlsConfig == nil </span><span class="cov10" title="3">{
- caCert, err := ioutil.ReadFile(caCertFile)
+ caCert, err := ioutil.ReadFile(caCertFile)
- if err != nil </span><span class="cov1" title="1">{
- return nil, err
- }</span>
+ if err != nil </span><span class="cov1" title="1">{
+ return nil, err
+ }</span>
- <span class="cov6" title="2">caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
+ <span class="cov6" title="2">caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
- tlsConfig = &amp;tls.Config{
- ClientAuth: tls.RequireAndVerifyClientCert,
- ClientCAs: caCertPool,
- MinVersion: tls.VersionTLS12,
- }
- tlsConfig.BuildNameToCertificate()</span>
+ tlsConfig := &amp;tls.Config{
+ // Change to RequireAndVerify once we have mandatory certs
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ ClientCAs: caCertPool,
+ MinVersion: tls.VersionTLS12,
}
- <span class="cov6" title="2">return tlsConfig, nil</span>
+ tlsConfig.BuildNameToCertificate()
+ return tlsConfig, nil</span>
}
// GeneratePGPKeyPair produces a PGP key pair and returns
// two things:
// A base64 encoded form of the public part of the entity
// A base64 encoded form of the private key
-func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{
+func GeneratePGPKeyPair() (string, string, error) <span class="cov10" title="3">{
+
var entity *openpgp.Entity
- entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", nil)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ config := &amp;packet.Config{
+ DefaultHash: crypto.SHA256,
+ }
+
+ entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
+ if smslogger.CheckError(err, "Create Entity") != nil </span><span class="cov0" title="0">{
return "", "", err
}</span>
// Sign the identity in the entity
- <span class="cov0" title="0">for _, id := range entity.Identities </span><span class="cov0" title="0">{
+ <span class="cov10" title="3">for _, id := range entity.Identities </span><span class="cov10" title="3">{
err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Entity") != nil </span><span class="cov0" title="0">{
return "", "", err
}</span>
}
// Sign the subkey in the entity
- <span class="cov0" title="0">for _, subkey := range entity.Subkeys </span><span class="cov0" title="0">{
+ <span class="cov10" title="3">for _, subkey := range entity.Subkeys </span><span class="cov10" title="3">{
err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Subkey") != nil </span><span class="cov0" title="0">{
return "", "", err
}</span>
}
- <span class="cov0" title="0">buffer := new(bytes.Buffer)
+ <span class="cov10" title="3">buffer := new(bytes.Buffer)
entity.Serialize(buffer)
pbkey := base64.StdEncoding.EncodeToString(buffer.Bytes())
@@ -186,40 +187,96 @@ func GeneratePGPKeyPair() (string, string, error) <span class="cov0" title="0">{
return pbkey, prkey, nil</span>
}
-// DecryptPGPBytes decrypts a PGP encoded input string and returns
+// EncryptPGPString takes data and a public key and encrypts using that
+// public key
+func EncryptPGPString(data string, pbKey string) (string, error) <span class="cov6" title="2">{
+
+ pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
+ if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+
+ <span class="cov6" title="2">dataBytes := []byte(data)
+
+ pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
+ if smslogger.CheckError(err, "Reading entity from PGP key") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+
+ // encrypt string
+ <span class="cov6" title="2">buf := new(bytes.Buffer)
+ out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
+ if smslogger.CheckError(err, "Creating Encryption Pipe") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+
+ <span class="cov6" title="2">_, err = out.Write(dataBytes)
+ if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+
+ <span class="cov6" title="2">err = out.Close()
+ if smslogger.CheckError(err, "Closing Encryption Pipe") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+
+ <span class="cov6" title="2">crp := base64.StdEncoding.EncodeToString(buf.Bytes())
+ return crp, nil</span>
+}
+
+// DecryptPGPString decrypts a PGP encoded input string and returns
// a base64 representation of the decoded string
-func DecryptPGPBytes(data string, prKey string) (string, error) <span class="cov0" title="0">{
+func DecryptPGPString(data string, prKey string) (string, error) <span class="cov1" title="1">{
+
// Convert private key to bytes from base64
prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+ if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil </span><span class="cov0" title="0">{
return "", err
}</span>
- <span class="cov0" title="0">dataBytes, err := base64.StdEncoding.DecodeString(data)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+ <span class="cov1" title="1">dataBytes, err := base64.StdEncoding.DecodeString(data)
+ if smslogger.CheckError(err, "Decoding base64 data") != nil </span><span class="cov0" title="0">{
return "", err
}</span>
- <span class="cov0" title="0">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+ <span class="cov1" title="1">prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
+ if smslogger.CheckError(err, "Read Entity") != nil </span><span class="cov0" title="0">{
return "", err
}</span>
- <span class="cov0" title="0">prEntityList := &amp;openpgp.EntityList{prEntity}
+ <span class="cov1" title="1">prEntityList := &amp;openpgp.EntityList{prEntity}
message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError("Error Decrypting message: " + err.Error())
+ if smslogger.CheckError(err, "Decrypting Message") != nil </span><span class="cov0" title="0">{
return "", err
}</span>
- <span class="cov0" title="0">var retBuf bytes.Buffer
+ <span class="cov1" title="1">var retBuf bytes.Buffer
retBuf.ReadFrom(message.UnverifiedBody)
return retBuf.String(), nil</span>
}
+
+// ReadFromFile reads a file and loads the PGP key into
+// a string
+func ReadFromFile(fileName string) (string, error) <span class="cov6" title="2">{
+
+ data, err := ioutil.ReadFile(fileName)
+ if smslogger.CheckError(err, "Read from file") != nil </span><span class="cov0" title="0">{
+ return "", err
+ }</span>
+ <span class="cov6" title="2">return string(data), nil</span>
+}
+
+// WriteToFile writes a PGP key into a file.
+// It will truncate the file if it exists
+func WriteToFile(data string, fileName string) error <span class="cov0" title="0">{
+
+ err := ioutil.WriteFile(fileName, []byte(data), 0600)
+ if smslogger.CheckError(err, "Write to file") != nil </span><span class="cov0" title="0">{
+ return err
+ }</span>
+ <span class="cov0" title="0">return nil</span>
+}
</pre>
<pre class="file" id="file1" style="display: none">/*
@@ -264,6 +321,7 @@ type SecretBackend interface {
Init() error
GetStatus() (bool, error)
Unseal(shard string) error
+ RegisterQuorum(pgpkey string) (string, error)
GetSecret(dom string, sec string) (Secret, error)
ListSecret(dom string) ([]string, error)
@@ -276,19 +334,18 @@ type SecretBackend interface {
}
// InitSecretBackend returns an interface implementation
-func InitSecretBackend() (SecretBackend, error) <span class="cov10" title="2">{
+func InitSecretBackend() (SecretBackend, error) <span class="cov8" title="1">{
backendImpl := &amp;Vault{
- vaultAddress: smsconfig.SMSConfig.VaultAddress,
+ vaultAddress: smsconfig.SMSConfig.BackendAddress,
vaultToken: smsconfig.SMSConfig.VaultToken,
}
err := backendImpl.Init()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitSecretBackend") != nil </span><span class="cov0" title="0">{
return nil, err
}</span>
- <span class="cov10" title="2">return backendImpl, nil</span>
+ <span class="cov8" title="1">return backendImpl, nil</span>
}
// LoginBackend Interface that will be implemented for various login backends
@@ -330,68 +387,108 @@ import (
// Vault is the main Struct used in Backend to initialize the struct
type Vault struct {
sync.Mutex
- engineType string
- initRoleDone bool
- policyName string
- roleID string
- secretID string
- vaultAddress string
- vaultClient *vaultapi.Client
- vaultMount string
- vaultTempTokenTTL time.Time
- vaultToken string
- unsealShards []string
- rootToken string
- pgpPub string
- pgpPr string
+ initRoleDone bool
+ policyName string
+ roleID string
+ secretID string
+ vaultAddress string
+ vaultClient *vaultapi.Client
+ vaultMountPrefix string
+ internalDomain string
+ internalDomainMounted bool
+ vaultTempTokenTTL time.Time
+ vaultToken string
+ shards []string
+ prkey string
}
-// Init will initialize the vault connection
-// It will also create the initial policy if it does not exist
-// TODO: Check to see if we need to wait for vault to be running
-func (v *Vault) Init() error <span class="cov4" title="3">{
+// initVaultClient will create the initial
+// Vault strcuture and populate it with the
+// right values and it will also create
+// a vault client
+func (v *Vault) initVaultClient() error <span class="cov6" title="11">{
+
vaultCFG := vaultapi.DefaultConfig()
vaultCFG.Address = v.vaultAddress
client, err := vaultapi.NewClient(vaultCFG)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
- return errors.New("Unable to create new vault client")
+ if smslogger.CheckError(err, "Create new vault client") != nil </span><span class="cov0" title="0">{
+ return err
}</span>
- <span class="cov4" title="3">v.engineType = "kv"
- v.initRoleDone = false
+ <span class="cov6" title="11">v.initRoleDone = false
v.policyName = "smsvaultpolicy"
v.vaultClient = client
- v.vaultMount = "sms"
+ v.vaultMountPrefix = "sms"
+ v.internalDomain = "smsinternaldomain"
+ v.internalDomainMounted = false
+ v.prkey = ""
+ return nil</span>
+}
- err = v.initRole()
- if err != nil </span><span class="cov2" title="2">{
- smslogger.WriteError(err.Error())
+// Init will initialize the vault connection
+// It will also initialize vault if it is not
+// already initialized.
+// The initial policy will also be created
+func (v *Vault) Init() error <span class="cov1" title="1">{
+
+ v.initVaultClient()
+ // Initialize vault if it is not already
+ // Returns immediately if it is initialized
+ v.initializeVault()
+
+ err := v.initRole()
+ if smslogger.CheckError(err, "InitRole First Attempt") != nil </span><span class="cov0" title="0">{
smslogger.WriteInfo("InitRole will try again later")
}</span>
- <span class="cov4" title="3">return nil</span>
+ <span class="cov1" title="1">return nil</span>
}
// GetStatus returns the current seal status of vault
-func (v *Vault) GetStatus() (bool, error) <span class="cov4" title="3">{
+func (v *Vault) GetStatus() (bool, error) <span class="cov3" title="3">{
+
sys := v.vaultClient.Sys()
sealStatus, err := sys.SealStatus()
- if err != nil </span><span class="cov1" title="1">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Getting Status") != nil </span><span class="cov0" title="0">{
return false, errors.New("Error getting status")
}</span>
- <span class="cov2" title="2">return sealStatus.Sealed, nil</span>
+ <span class="cov3" title="3">return sealStatus.Sealed, nil</span>
+}
+
+// RegisterQuorum registers the PGP public key for a quorum client
+// We will return a shard to the client that is registering
+func (v *Vault) RegisterQuorum(pgpkey string) (string, error) <span class="cov0" title="0">{
+
+ v.Lock()
+ defer v.Unlock()
+
+ if v.shards == nil </span><span class="cov0" title="0">{
+ smslogger.WriteError("Invalid operation in RegisterQuorum")
+ return "", errors.New("Invalid operation")
+ }</span>
+ // Pop the slice
+ <span class="cov0" title="0">var sh string
+ sh, v.shards = v.shards[len(v.shards)-1], v.shards[:len(v.shards)-1]
+ if len(v.shards) == 0 </span><span class="cov0" title="0">{
+ v.shards = nil
+ }</span>
+
+ // Decrypt with SMS pgp Key
+ <span class="cov0" title="0">sh, _ = smsauth.DecryptPGPString(sh, v.prkey)
+ // Encrypt with Quorum client pgp key
+ sh, _ = smsauth.EncryptPGPString(sh, pgpkey)
+
+ return sh, nil</span>
}
// Unseal is a passthrough API that allows any
// unseal or initialization processes for the backend
func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{
+
sys := v.vaultClient.Sys()
_, err := sys.Unseal(shard)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Unseal Operation") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to execute unseal operation with specified shard")
}</span>
@@ -401,80 +498,140 @@ func (v *Vault) Unseal(shard string) error <span class="cov0" title="0">{
// GetSecret returns a secret mounted on a particular domain name
// The secret itself is referenced via its name which translates to
// a mount path in vault
-func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov6" title="6">{
+func (v *Vault) GetSecret(dom string, name string) (Secret, error) <span class="cov5" title="7">{
+
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Tocken Check") != nil </span><span class="cov0" title="0">{
return Secret{}, errors.New("Token check failed")
}</span>
- <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+ <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{
return Secret{}, errors.New("Unable to read Secret at provided path")
}</span>
// sec and err are nil in the case where a path does not exist
- <span class="cov6" title="6">if sec == nil </span><span class="cov0" title="0">{
+ <span class="cov5" title="7">if sec == nil </span><span class="cov0" title="0">{
smslogger.WriteWarn("Vault read was empty. Invalid Path")
return Secret{}, errors.New("Secret not found at the provided path")
}</span>
- <span class="cov6" title="6">return Secret{Name: name, Values: sec.Data}, nil</span>
+ <span class="cov5" title="7">return Secret{Name: name, Values: sec.Data}, nil</span>
}
// ListSecret returns a list of secret names on a particular domain
// The values of the secret are not returned
-func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov2" title="2">{
+func (v *Vault) ListSecret(dom string) ([]string, error) <span class="cov3" title="3">{
+
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
return nil, errors.New("Token check failed")
}</span>
- <span class="cov2" title="2">dom = v.vaultMount + "/" + dom
+ <span class="cov3" title="3">dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().List(dom)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil </span><span class="cov0" title="0">{
return nil, errors.New("Unable to read Secret at provided path")
}</span>
// sec and err are nil in the case where a path does not exist
- <span class="cov2" title="2">if sec == nil </span><span class="cov0" title="0">{
+ <span class="cov3" title="3">if sec == nil </span><span class="cov0" title="0">{
smslogger.WriteWarn("Vaultclient returned empty data")
return nil, errors.New("Secret not found at the provided path")
}</span>
- <span class="cov2" title="2">val, ok := sec.Data["keys"].([]interface{})
+ <span class="cov3" title="3">val, ok := sec.Data["keys"].([]interface{})
if !ok </span><span class="cov0" title="0">{
smslogger.WriteError("Secret not found at the provided path")
return nil, errors.New("Secret not found at the provided path")
}</span>
- <span class="cov2" title="2">retval := make([]string, len(val))
- for i, v := range val </span><span class="cov6" title="6">{
+ <span class="cov3" title="3">retval := make([]string, len(val))
+ for i, v := range val </span><span class="cov5" title="7">{
retval[i] = fmt.Sprint(v)
}</span>
- <span class="cov2" title="2">return retval, nil</span>
+ <span class="cov3" title="3">return retval, nil</span>
+}
+
+// Mounts the internal Domain if its not already mounted
+func (v *Vault) mountInternalDomain(name string) error <span class="cov5" title="8">{
+
+ if v.internalDomainMounted </span><span class="cov1" title="1">{
+ return nil
+ }</span>
+
+ <span class="cov5" title="7">name = strings.TrimSpace(name)
+ mountPath := v.vaultMountPrefix + "/" + name
+ mountInput := &amp;vaultapi.MountInput{
+ Type: "kv",
+ Description: "Mount point for domain: " + name,
+ Local: false,
+ SealWrap: false,
+ Config: vaultapi.MountConfigInput{},
+ }
+
+ err := v.vaultClient.Sys().Mount(mountPath, mountInput)
+ if smslogger.CheckError(err, "Mount internal Domain") != nil </span><span class="cov1" title="1">{
+ if strings.Contains(err.Error(), "existing mount") </span><span class="cov1" title="1">{
+ // It is already mounted
+ v.internalDomainMounted = true
+ return nil
+ }</span>
+ // Ran into some other error mounting it.
+ <span class="cov0" title="0">return errors.New("Unable to mount internal Domain")</span>
+ }
+
+ <span class="cov5" title="6">v.internalDomainMounted = true
+ return nil</span>
+}
+
+// Stores the UUID created for secretdomain in vault
+// under v.vaultMountPrefix / smsinternal domain
+func (v *Vault) storeUUID(uuid string, name string) error <span class="cov5" title="8">{
+
+ // Check if token is still valid
+ err := v.checkToken()
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
+ return errors.New("Token Check failed")
+ }</span>
+
+ <span class="cov5" title="8">err = v.mountInternalDomain(v.internalDomain)
+ if smslogger.CheckError(err, "Mount Internal Domain") != nil </span><span class="cov0" title="0">{
+ return err
+ }</span>
+
+ <span class="cov5" title="8">secret := Secret{
+ Name: name,
+ Values: map[string]interface{}{
+ "uuid": uuid,
+ },
+ }
+
+ err = v.CreateSecret(v.internalDomain, secret)
+ if smslogger.CheckError(err, "Write UUID to domain") != nil </span><span class="cov0" title="0">{
+ return err
+ }</span>
+
+ <span class="cov5" title="8">return nil</span>
}
// CreateSecretDomain mounts the kv backend on a path with the given name
-func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov2" title="2">{
+func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span class="cov5" title="8">{
+
// Check if token is still valid
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
return SecretDomain{}, errors.New("Token Check failed")
}</span>
- <span class="cov2" title="2">name = strings.TrimSpace(name)
- mountPath := v.vaultMount + "/" + name
+ <span class="cov5" title="8">name = strings.TrimSpace(name)
+ mountPath := v.vaultMountPrefix + "/" + name
mountInput := &amp;vaultapi.MountInput{
- Type: v.engineType,
+ Type: "kv",
Description: "Mount point for domain: " + name,
Local: false,
SealWrap: false,
@@ -482,171 +639,212 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) <span clas
}
err = v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Domain") != nil </span><span class="cov0" title="0">{
return SecretDomain{}, errors.New("Unable to create Secret Domain")
}</span>
- <span class="cov2" title="2">uuid, _ := uuid.GenerateUUID()
- return SecretDomain{uuid, name}, nil</span>
+ <span class="cov5" title="8">uuid, _ := uuid.GenerateUUID()
+ err = v.storeUUID(uuid, name)
+ if smslogger.CheckError(err, "Store UUID") != nil </span><span class="cov0" title="0">{
+ // Mount was successful at this point.
+ // Rollback the mount operation since we could not
+ // store the UUID for the mount.
+ v.vaultClient.Sys().Unmount(mountPath)
+ return SecretDomain{}, errors.New("Unable to store Secret Domain UUID. Retry")
+ }</span>
+
+ <span class="cov5" title="8">return SecretDomain{uuid, name}, nil</span>
}
// CreateSecret creates a secret mounted on a particular domain name
// The secret itself is mounted on a path specified by name
-func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov6" title="6">{
+func (v *Vault) CreateSecret(dom string, sec Secret) error <span class="cov7" title="18">{
+
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
return errors.New("Token check failed")
}</span>
- <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+ <span class="cov7" title="18">dom = v.vaultMountPrefix + "/" + dom
// Vault return is empty on successful write
// TODO: Check if values is not empty
_, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Secret") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to create Secret at provided path")
}</span>
- <span class="cov6" title="6">return nil</span>
+ <span class="cov7" title="18">return nil</span>
}
// DeleteSecretDomain deletes a secret domain which translates to
// an unmount operation on the given path in Vault
-func (v *Vault) DeleteSecretDomain(name string) error <span class="cov2" title="2">{
+func (v *Vault) DeleteSecretDomain(name string) error <span class="cov3" title="3">{
+
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
return errors.New("Token Check Failed")
}</span>
- <span class="cov2" title="2">name = strings.TrimSpace(name)
- mountPath := v.vaultMount + "/" + name
+ <span class="cov3" title="3">name = strings.TrimSpace(name)
+ mountPath := v.vaultMountPrefix + "/" + name
err = v.vaultClient.Sys().Unmount(mountPath)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Domain") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to delete domain specified")
}</span>
- <span class="cov2" title="2">return nil</span>
+ <span class="cov3" title="3">return nil</span>
}
// DeleteSecret deletes a secret mounted on the path provided
-func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov6" title="6">{
+func (v *Vault) DeleteSecret(dom string, name string) error <span class="cov5" title="7">{
+
err := v.checkToken()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil </span><span class="cov0" title="0">{
return errors.New("Token check failed")
}</span>
- <span class="cov6" title="6">dom = v.vaultMount + "/" + dom
+ <span class="cov5" title="7">dom = v.vaultMountPrefix + "/" + dom
// Vault return is empty on successful delete
_, err = v.vaultClient.Logical().Delete(dom + "/" + name)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Secret") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to delete Secret at provided path")
}</span>
- <span class="cov6" title="6">return nil</span>
+ <span class="cov5" title="7">return nil</span>
}
-// initRole is called only once during the service bring up
-func (v *Vault) initRole() error <span class="cov4" title="3">{
+// initRole is called only once during SMS bring up
+// It initially creates a role and secret id associated with
+// that role. Later restarts will use the existing role-id
+// and secret-id stored on disk
+func (v *Vault) initRole() error <span class="cov10" title="56">{
+
+ if v.initRoleDone </span><span class="cov9" title="48">{
+ return nil
+ }</span>
+
// Use the root token once here
- v.vaultClient.SetToken(v.vaultToken)
+ <span class="cov5" title="8">v.vaultClient.SetToken(v.vaultToken)
defer v.vaultClient.ClearToken()
- rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
+ // Check if roleID and secretID has already been created
+ rID, error := smsauth.ReadFromFile("auth/role")
+ if error != nil </span><span class="cov5" title="7">{
+ smslogger.WriteWarn("Unable to find RoleID. Generating...")
+ }</span><span class="cov1" title="1"> else {
+ sID, error := smsauth.ReadFromFile("auth/secret")
+ if error != nil </span><span class="cov0" title="0">{
+ smslogger.WriteWarn("Unable to find secretID. Generating...")
+ }</span><span class="cov1" title="1"> else {
+ v.roleID = rID
+ v.secretID = sID
+ v.initRoleDone = true
+ return nil
+ }</span>
+ }
+
+ <span class="cov5" title="7">rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
- if err != nil </span><span class="cov2" title="2">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Creating Policy") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to create policy for approle creation")
}</span>
- <span class="cov1" title="1">rName := v.vaultMount + "-role"
- data := map[string]interface{}{
- "token_ttl": "60m",
- "policies": [2]string{"default", v.policyName},
- }
-
//Check if applrole is mounted
- authMounts, err := v.vaultClient.Sys().ListAuth()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ <span class="cov5" title="7">authMounts, err := v.vaultClient.Sys().ListAuth()
+ if smslogger.CheckError(err, "Mount Auth Backend") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to get mounted auth backends")
}</span>
- <span class="cov1" title="1">approleMounted := false
- for k, v := range authMounts </span><span class="cov1" title="1">{
- if v.Type == "approle" &amp;&amp; k == "approle/" </span><span class="cov1" title="1">{
+ <span class="cov5" title="7">approleMounted := false
+ for k, v := range authMounts </span><span class="cov5" title="7">{
+ if v.Type == "approle" &amp;&amp; k == "approle/" </span><span class="cov0" title="0">{
approleMounted = true
break</span>
}
}
// Mount approle in case its not already mounted
- <span class="cov1" title="1">if !approleMounted </span><span class="cov0" title="0">{
+ <span class="cov5" title="7">if !approleMounted </span><span class="cov5" title="7">{
v.vaultClient.Sys().EnableAuth("approle", "approle", "")
}</span>
+ <span class="cov5" title="7">rName := v.vaultMountPrefix + "-role"
+ data := map[string]interface{}{
+ "token_ttl": "60m",
+ "policies": [2]string{"default", v.policyName},
+ }
+
// Create a role-id
- <span class="cov1" title="1">v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
+ v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create RoleID") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to create role ID for approle")
}</span>
- <span class="cov1" title="1">v.roleID = sec.Data["role_id"].(string)
+ <span class="cov5" title="7">v.roleID = sec.Data["role_id"].(string)
// Create a secret-id to go with it
sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
map[string]interface{}{})
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create SecretID") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to create secret ID for role")
}</span>
- <span class="cov1" title="1">v.secretID = sec.Data["secret_id"].(string)
+ <span class="cov5" title="7">v.secretID = sec.Data["secret_id"].(string)
v.initRoleDone = true
+ /*
+ * Revoke the Root token.
+ * If a new Root Token is needed, it will need to be created
+ * using the unseal shards.
+ */
+ err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
+ if smslogger.CheckError(err, "Revoke Root Token") != nil </span><span class="cov0" title="0">{
+ smslogger.WriteWarn("Unable to Revoke Token")
+ }</span><span class="cov5" title="7"> else {
+ // Revoked successfully and clear it
+ v.vaultToken = ""
+ }</span>
+
+ // Store the role-id and secret-id
+ // We will need this if SMS restarts
+ <span class="cov5" title="7">smsauth.WriteToFile(v.roleID, "auth/role")
+ smsauth.WriteToFile(v.secretID, "auth/secret")
+
return nil</span>
}
// Function checkToken() gets called multiple times to create
// temporary tokens
-func (v *Vault) checkToken() error <span class="cov10" title="24">{
+func (v *Vault) checkToken() error <span class="cov9" title="54">{
+
v.Lock()
defer v.Unlock()
// Init Role if it is not yet done
// Role needs to be created before token can be created
- if v.initRoleDone == false </span><span class="cov0" title="0">{
- err := v.initRole()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
- return errors.New("Unable to initRole in checkToken")
- }</span>
- }
+ err := v.initRole()
+ if err != nil </span><span class="cov0" title="0">{
+ smslogger.WriteError(err.Error())
+ return errors.New("Unable to initRole in checkToken")
+ }</span>
// Return immediately if token still has life
- <span class="cov10" title="24">if v.vaultClient.Token() != "" &amp;&amp;
- time.Since(v.vaultTempTokenTTL) &lt; time.Minute*50 </span><span class="cov9" title="23">{
+ <span class="cov9" title="54">if v.vaultClient.Token() != "" &amp;&amp;
+ time.Since(v.vaultTempTokenTTL) &lt; time.Minute*50 </span><span class="cov9" title="47">{
return nil
}</span>
// Create a temporary token using our roleID and secretID
- <span class="cov1" title="1">out, err := v.vaultClient.Logical().Write("auth/approle/login",
+ <span class="cov5" title="7">out, err := v.vaultClient.Logical().Write("auth/approle/login",
map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Temp Token") != nil </span><span class="cov0" title="0">{
return errors.New("Unable to create Temporary Token for Role")
}</span>
- <span class="cov1" title="1">tok, err := out.TokenID()
+ <span class="cov5" title="7">tok, err := out.TokenID()
v.vaultTempTokenTTL = time.Now()
v.vaultClient.SetToken(tok)
@@ -655,31 +853,53 @@ func (v *Vault) checkToken() error <span class="cov10" title="24">{
// vaultInit() is used to initialize the vault in cases where it is not
// initialized. This happens once during intial bring up.
-func (v *Vault) initializeVault() error <span class="cov0" title="0">{
- initReq := &amp;vaultapi.InitRequest{
- SecretShares: 5,
+func (v *Vault) initializeVault() error <span class="cov2" title="2">{
+
+ // Check for vault init status and don't exit till it is initialized
+ for </span><span class="cov2" title="2">{
+ init, err := v.vaultClient.Sys().InitStatus()
+ if smslogger.CheckError(err, "Get Vault Init Status") != nil </span><span class="cov0" title="0">{
+ smslogger.WriteInfo("Trying again in 10s...")
+ time.Sleep(time.Second * 10)
+ continue</span>
+ }
+ // Did not get any error
+ <span class="cov2" title="2">if init == true </span><span class="cov1" title="1">{
+ smslogger.WriteInfo("Vault is already Initialized")
+ return nil
+ }</span>
+
+ // init status is false
+ // break out of loop and finish initialization
+ <span class="cov1" title="1">smslogger.WriteInfo("Vault is not initialized. Initializing...")
+ break</span>
+ }
+
+ // Hardcoded this to 3. We should make this configurable
+ // in the future
+ <span class="cov1" title="1">initReq := &amp;vaultapi.InitRequest{
+ SecretShares: 3,
SecretThreshold: 3,
}
pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
- if err != nil </span><span class="cov0" title="0">{
+
+ if smslogger.CheckError(err, "Generating PGP Keys") != nil </span><span class="cov0" title="0">{
smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
- }</span><span class="cov0" title="0"> else {
- initReq.PGPKeys = []string{pbkey, pbkey, pbkey, pbkey, pbkey}
+ }</span><span class="cov1" title="1"> else {
+ initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
initReq.RootTokenPGPKey = pbkey
- v.pgpPub = pbkey
- v.pgpPr = prkey
}</span>
- <span class="cov0" title="0">resp, err := v.vaultClient.Sys().Init(initReq)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ <span class="cov1" title="1">resp, err := v.vaultClient.Sys().Init(initReq)
+ if smslogger.CheckError(err, "Initialize Vault") != nil </span><span class="cov0" title="0">{
return errors.New("FATAL: Unable to initialize Vault")
}</span>
- <span class="cov0" title="0">if resp != nil </span><span class="cov0" title="0">{
- v.unsealShards = resp.KeysB64
- v.rootToken = resp.RootToken
+ <span class="cov1" title="1">if resp != nil </span><span class="cov1" title="1">{
+ v.prkey = prkey
+ v.shards = resp.KeysB64
+ v.vaultToken, _ = smsauth.DecryptPGPString(resp.RootToken, prkey)
return nil
}</span>
@@ -708,6 +928,7 @@ package config
import (
"encoding/json"
"os"
+ smslogger "sms/log"
)
// SMSConfiguration loads up all the values that are used to configure
@@ -718,8 +939,10 @@ type SMSConfiguration struct {
ServerCert string `json:"servercert"`
ServerKey string `json:"serverkey"`
- VaultAddress string `json:"vaultaddress"`
- VaultToken string `json:"vaulttoken"`
+ BackendAddress string `json:"smsdbaddress"`
+ VaultToken string `json:"vaulttoken"`
+ DisableTLS bool `json:"disable_tls"`
+ BackendAddressEnvVariable string `json:"smsdburlenv"`
}
// SMSConfig is the structure that stores the configuration
@@ -734,12 +957,19 @@ func ReadConfigFile(file string) (*SMSConfiguration, error) <span class="cov10"
}</span>
<span class="cov6" title="2">defer f.Close()
- SMSConfig = &amp;SMSConfiguration{}
+ // Default behaviour is to enable TLS
+ SMSConfig = &amp;SMSConfiguration{DisableTLS: false}
decoder := json.NewDecoder(f)
err = decoder.Decode(SMSConfig)
if err != nil </span><span class="cov0" title="0">{
return nil, err
}</span>
+
+ <span class="cov6" title="2">if SMSConfig.BackendAddress == "" &amp;&amp; SMSConfig.BackendAddressEnvVariable != "" </span><span class="cov0" title="0">{
+ // Get the value from ENV variable
+ smslogger.WriteInfo("Using Environment Variable: " + SMSConfig.BackendAddressEnvVariable)
+ SMSConfig.BackendAddress = os.Getenv(SMSConfig.BackendAddressEnvVariable)
+ }</span>
}
<span class="cov6" title="2">return SMSConfig, nil</span>
@@ -785,29 +1015,24 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
var d smsbackend.SecretDomain
err := json.NewDecoder(r.Body).Decode(&amp;d)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusBadRequest)
return
}</span>
<span class="cov6" title="3">dom, err := h.secretBackend.CreateSecretDomain(d.Name)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
- <span class="cov6" title="3">jdata, err := json.Marshal(dom)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(http.StatusCreated)
+ err = json.NewEncoder(w).Encode(dom)
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
-
- <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
- w.WriteHeader(http.StatusCreated)
- w.Write(jdata)</span>
}
// deleteSecretDomainHandler deletes a secret domain with the name provided
@@ -816,8 +1041,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
domName := vars["domName"]
err := h.secretBackend.DeleteSecretDomain(domName)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
@@ -834,15 +1058,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) <sp
// Get secrets to be stored from body
var b smsbackend.Secret
err := json.NewDecoder(r.Body).Decode(&amp;b)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusBadRequest)
return
}</span>
<span class="cov10" title="7">err = h.secretBackend.CreateSecret(domName, b)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
@@ -857,21 +1079,17 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) <span
secName := vars["secretName"]
sec, err := h.secretBackend.GetSecret(domName, secName)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
- <span class="cov10" title="7">jdata, err := json.Marshal(sec)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json")
+ err = json.NewEncoder(w).Encode(sec)
+ if smslogger.CheckError(err, "GetSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
-
- <span class="cov10" title="7">w.Header().Set("Content-Type", "application/json")
- w.Write(jdata)</span>
}
// listSecretHandler handles listing all secrets under a particular domain name
@@ -880,8 +1098,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span
domName := vars["domName"]
secList, err := h.secretBackend.ListSecret(domName)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
@@ -893,15 +1110,12 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) <span
secList,
}
- jdata, err := json.Marshal(retStruct)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ w.Header().Set("Content-Type", "application/json")
+ err = json.NewEncoder(w).Encode(retStruct)
+ if smslogger.CheckError(err, "ListSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
-
- <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
- w.Write(jdata)</span>
}
// deleteSecretHandler handles deleting a secret by given domain name and secret name
@@ -911,37 +1125,34 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) <sp
secName := vars["secretName"]
err := h.secretBackend.DeleteSecret(domName, secName)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
-}
-// struct that tracks various status items for SMS and backend
-type backendStatus struct {
- Seal bool `json:"sealstatus"`
+ <span class="cov10" title="7">w.WriteHeader(http.StatusNoContent)</span>
}
// statusHandler returns information related to SMS and SMS backend services
-func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov6" title="3">{
+func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) <span class="cov7" title="4">{
s, err := h.secretBackend.GetStatus()
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
- <span class="cov6" title="3">status := backendStatus{Seal: s}
- jdata, err := json.Marshal(status)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ <span class="cov7" title="4">status := struct {
+ Seal bool `json:"sealstatus"`
+ }{
+ s,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ err = json.NewEncoder(w).Encode(status)
+ if smslogger.CheckError(err, "StatusHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
-
- <span class="cov6" title="3">w.Header().Set("Content-Type", "application/json")
- w.Write(jdata)</span>
}
// loginHandler handles login via password and username
@@ -961,15 +1172,53 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) <span cla
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&amp;inp)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}</span>
<span class="cov0" title="0">err = h.secretBackend.Unseal(inp.UnsealShard)
- if err != nil </span><span class="cov0" title="0">{
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil </span><span class="cov0" title="0">{
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }</span>
+}
+
+// registerHandler allows the quorum clients to register with SMS
+// with their PGP public keys that are then used by sms for backend
+// initialization
+func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) <span class="cov1" title="1">{
+ // Get shards to be used for unseal
+ type registerStruct struct {
+ PGPKey string `json:"pgpkey"`
+ QuorumID string `json:"quorumid"`
+ }
+
+ var inp registerStruct
+ decoder := json.NewDecoder(r.Body)
+ decoder.DisallowUnknownFields()
+ err := decoder.Decode(&amp;inp)
+ if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
+ http.Error(w, "Bad input JSON", http.StatusBadRequest)
+ return
+ }</span>
+
+ <span class="cov1" title="1">sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
+ if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }</span>
+
+ // Creating a struct for return data
+ <span class="cov1" title="1">shStruct := struct {
+ Shard string `json:"shard"`
+ }{
+ sh,
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ err = json.NewEncoder(w).Encode(shStruct)
+ if smslogger.CheckError(err, "RegisterHandler") != nil </span><span class="cov0" title="0">{
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}</span>
@@ -987,8 +1236,9 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti
// Initialization APIs which will be used by quorum client
// to unseal and to provide root token to sms service
- router.HandleFunc("/v1/sms/status", h.statusHandler).Methods("GET")
- router.HandleFunc("/v1/sms/unseal", h.unsealHandler).Methods("POST")
+ router.HandleFunc("/v1/sms/quorum/status", h.statusHandler).Methods("GET")
+ router.HandleFunc("/v1/sms/quorum/unseal", h.unsealHandler).Methods("POST")
+ router.HandleFunc("/v1/sms/quorum/register", h.registerHandler).Methods("POST")
router.HandleFunc("/v1/sms/domain", h.createSecretDomainHandler).Methods("POST")
router.HandleFunc("/v1/sms/domain/{domName}", h.deleteSecretDomainHandler).Methods("DELETE")
@@ -1021,53 +1271,85 @@ func CreateRouter(b smsbackend.SecretBackend) http.Handler <span class="cov4" ti
package log
import (
+ "fmt"
"log"
"os"
)
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
// Init will be called by sms.go before any other packages use it
-func Init(filePath string) <span class="cov8" title="1">{
- f, err := os.Create(filePath)
+func Init(filePath string) <span class="cov1" title="1">{
+
+ stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
+ if filePath == "" </span><span class="cov0" title="0">{
+ // We will just to std streams
+ return
+ }</span>
+
+ <span class="cov1" title="1">f, err := os.Create(filePath)
if err != nil </span><span class="cov0" title="0">{
- log.Println("Unable to create a log file")
- log.Println(err)
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
- }</span><span class="cov8" title="1"> else {
- errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+ stdErr.Println("Unable to create log file: " + err.Error())
+ return
}</span>
+
+ <span class="cov1" title="1">errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)</span>
}
// WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
func WriteError(msg string) <span class="cov0" title="0">{
- if errLogger != nil </span><span class="cov0" title="0">{
- errLogger.Println(msg)
+ if errL != nil </span><span class="cov0" title="0">{
+ errL.Output(2, fmt.Sprintln(msg))
+ }</span>
+ <span class="cov0" title="0">if stdErr != nil </span><span class="cov0" title="0">{
+ stdErr.Output(2, fmt.Sprintln(msg))
}</span>
}
// WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
func WriteWarn(msg string) <span class="cov0" title="0">{
- if warnLogger != nil </span><span class="cov0" title="0">{
- warnLogger.Println(msg)
+ if warnL != nil </span><span class="cov0" title="0">{
+ warnL.Output(2, fmt.Sprintln(msg))
+ }</span>
+ <span class="cov0" title="0">if stdWarn != nil </span><span class="cov0" title="0">{
+ stdWarn.Output(2, fmt.Sprintln(msg))
}</span>
}
// WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
-func WriteInfo(msg string) <span class="cov0" title="0">{
- if infoLogger != nil </span><span class="cov0" title="0">{
- infoLogger.Println(msg)
+// defined during its creation with INFO prefix
+func WriteInfo(msg string) <span class="cov1" title="1">{
+ if infoL != nil </span><span class="cov1" title="1">{
+ infoL.Output(2, fmt.Sprintln(msg))
+ }</span>
+ <span class="cov1" title="1">if stdInfo != nil </span><span class="cov1" title="1">{
+ stdInfo.Output(2, fmt.Sprintln(msg))
}</span>
}
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error <span class="cov10" title="116">{
+ if err != nil </span><span class="cov1" title="1">{
+ msg := topic + ": " + err.Error()
+ if errL != nil </span><span class="cov1" title="1">{
+ errL.Output(2, fmt.Sprintln(msg))
+ }</span>
+ <span class="cov1" title="1">if stdErr != nil </span><span class="cov1" title="1">{
+ stdErr.Output(2, fmt.Sprintln(msg))
+ }</span>
+ <span class="cov1" title="1">return err</span>
+ }
+ <span class="cov9" title="115">return nil</span>
+}
</pre>
<pre class="file" id="file6" style="display: none">/*
@@ -1119,16 +1401,9 @@ func main() <span class="cov8" title="1">{
<span class="cov8" title="1">httpRouter := smshandler.CreateRouter(backendImpl)
- // TODO: Use CA certificate from AAF
- tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile)
- if err != nil </span><span class="cov0" title="0">{
- log.Fatal(err)
- }</span>
-
- <span class="cov8" title="1">httpServer := &amp;http.Server{
- Handler: httpRouter,
- Addr: ":10443",
- TLSConfig: tlsConfig,
+ httpServer := &amp;http.Server{
+ Handler: httpRouter,
+ Addr: ":10443",
}
// Listener for SIGINT so that it returns cleanly
@@ -1141,8 +1416,22 @@ func main() <span class="cov8" title="1">{
close(connectionsClose)
}</span>()
- <span class="cov8" title="1">err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey)
- if err != nil &amp;&amp; err != http.ErrServerClosed </span><span class="cov0" title="0">{
+ // Start in TLS mode by default
+ <span class="cov8" title="1">if smsConf.DisableTLS == true </span><span class="cov0" title="0">{
+ smslogger.WriteWarn("TLS is Disabled")
+ err = httpServer.ListenAndServe()
+ }</span><span class="cov8" title="1"> else {
+ // TODO: Use CA certificate from AAF
+ tlsConfig, err := smsauth.GetTLSConfig(smsConf.CAFile)
+ if err != nil </span><span class="cov0" title="0">{
+ log.Fatal(err)
+ }</span>
+
+ <span class="cov8" title="1">httpServer.TLSConfig = tlsConfig
+ err = httpServer.ListenAndServeTLS(smsConf.ServerCert, smsConf.ServerKey)</span>
+ }
+
+ <span class="cov8" title="1">if err != nil &amp;&amp; err != http.ErrServerClosed </span><span class="cov0" title="0">{
log.Fatal(err)
}</span>
diff --git a/sms-service/src/sms/coverage.md b/sms-service/doc/coverage.md
index 6168342..6168342 100644
--- a/sms-service/src/sms/coverage.md
+++ b/sms-service/doc/coverage.md
diff --git a/sms-service/src/quorumclient/quorumclient.go b/sms-service/src/quorumclient/quorumclient.go
index 05fc967..dfa1a26 100644
--- a/sms-service/src/quorumclient/quorumclient.go
+++ b/sms-service/src/quorumclient/quorumclient.go
@@ -37,13 +37,13 @@ func loadPGPKeys(prKeyPath string, pbKeyPath string) (string, string, error) {
var pbkey, prkey string
generated := false
prkey, err := smsauth.ReadFromFile(prKeyPath)
- if err != nil {
- smslogger.WriteWarn("No Private Key found. Generating...")
+ if smslogger.CheckError(err, "LoadPGP Private Key") != nil {
+ smslogger.WriteInfo("No Private Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
} else {
pbkey, err = smsauth.ReadFromFile(pbKeyPath)
- if err != nil {
+ if smslogger.CheckError(err, "LoadPGP Public Key") != nil {
smslogger.WriteWarn("No Public Key found. Generating...")
pbkey, prkey, _ = smsauth.GeneratePGPKeyPair()
generated = true
@@ -70,7 +70,7 @@ func main() {
prKeyPath := filepath.Join("auth", podName, "prkey")
shardPath := filepath.Join("auth", podName, "shard")
- smslogger.Init("")
+ smslogger.Init("quorum.log")
smslogger.WriteInfo("Starting Log for Quorum Client")
/*
@@ -80,7 +80,7 @@ func main() {
In Kubernetes, pod restarts will also change the hostname
*/
myID, err := smsauth.ReadFromFile(idFilePath)
- if err != nil {
+ if smslogger.CheckError(err, "Read ID") != nil {
smslogger.WriteWarn("Unable to find an ID for this client. Generating...")
myID, _ = uuid.GenerateUUID()
smsauth.WriteToFile(myID, idFilePath)
@@ -93,7 +93,7 @@ func main() {
*/
registrationDone := true
myShard, err := smsauth.ReadFromFile(shardPath)
- if err != nil {
+ if smslogger.CheckError(err, "Read Shard") != nil {
smslogger.WriteWarn("Unable to find a shard file. Registering with SMS...")
registrationDone = false
}
@@ -160,8 +160,7 @@ func main() {
//URL and Port is configured in config file
response, err := client.Get(cfg.BackEndURL + "/v1/sms/quorum/status")
- if err != nil {
- smslogger.WriteError("Unable to connect to SMS. Retrying...")
+ if smslogger.CheckError(err, "Connect to SMS") != nil {
continue
}
@@ -178,8 +177,7 @@ func main() {
if !registrationDone {
body := strings.NewReader(`{"pgpkey":"` + pbkey + `","quorumid":"` + myID + `"}`)
res, err := client.Post(cfg.BackEndURL+"/v1/sms/quorum/register", "application/json", body)
- if err != nil {
- smslogger.WriteError("Ran into error during registration. Retrying...")
+ if smslogger.CheckError(err, "Register with SMS") != nil {
continue
}
registrationDone = true
@@ -195,8 +193,8 @@ func main() {
body := strings.NewReader(`{"unsealshard":"` + decShard + `"}`)
//URL and PORT is configured via config file
response, err = client.Post(cfg.BackEndURL+"/v1/sms/quorum/unseal", "application/json", body)
- if err != nil {
- smslogger.WriteError("Error unsealing vault. Retrying... " + err.Error())
+ if smslogger.CheckError(err, "Unsealing Vault") != nil {
+ continue
}
}
}
diff --git a/sms-service/src/sms/Gopkg.lock b/sms-service/src/sms/Gopkg.lock
index c7684c7..2c09256 100644
--- a/sms-service/src/sms/Gopkg.lock
+++ b/sms-service/src/sms/Gopkg.lock
@@ -477,6 +477,6 @@
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
- inputs-digest = "d19e17a023506ab731b0f26c6fcfebe581d4d5194af094aecea5e72daddd3ead"
+ inputs-digest = "8280cde72a3ab78ad00d13c192de5920d188f3052f45884563896cab659469f9"
solver-name = "gps-cdcl"
solver-version = 1
diff --git a/sms-service/src/sms/auth/auth.go b/sms-service/src/sms/auth/auth.go
index 7172505..038e31d 100644
--- a/sms-service/src/sms/auth/auth.go
+++ b/sms-service/src/sms/auth/auth.go
@@ -29,39 +29,27 @@ import (
smslogger "sms/log"
)
-var tlsConfig *tls.Config
-
-func checkError(err error, topic string) error {
- if err != nil {
- smslogger.WriteError(topic + ": " + err.Error())
- return err
- }
-
- return nil
-}
-
// GetTLSConfig initializes a tlsConfig using the CA's certificate
// This config is then used to enable the server for mutual TLS
func GetTLSConfig(caCertFile string) (*tls.Config, error) {
+
// Initialize tlsConfig once
- if tlsConfig == nil {
- caCert, err := ioutil.ReadFile(caCertFile)
+ caCert, err := ioutil.ReadFile(caCertFile)
- if err != nil {
- return nil, err
- }
+ if err != nil {
+ return nil, err
+ }
- caCertPool := x509.NewCertPool()
- caCertPool.AppendCertsFromPEM(caCert)
+ caCertPool := x509.NewCertPool()
+ caCertPool.AppendCertsFromPEM(caCert)
- tlsConfig = &tls.Config{
- // Change to RequireAndVerify once we have mandatory certs
- ClientAuth: tls.VerifyClientCertIfGiven,
- ClientCAs: caCertPool,
- MinVersion: tls.VersionTLS12,
- }
- tlsConfig.BuildNameToCertificate()
+ tlsConfig := &tls.Config{
+ // Change to RequireAndVerify once we have mandatory certs
+ ClientAuth: tls.VerifyClientCertIfGiven,
+ ClientCAs: caCertPool,
+ MinVersion: tls.VersionTLS12,
}
+ tlsConfig.BuildNameToCertificate()
return tlsConfig, nil
}
@@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) {
// A base64 encoded form of the public part of the entity
// A base64 encoded form of the private key
func GeneratePGPKeyPair() (string, string, error) {
+
var entity *openpgp.Entity
config := &packet.Config{
DefaultHash: crypto.SHA256,
}
entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Entity") != nil {
return "", "", err
}
// Sign the identity in the entity
for _, id := range entity.Identities {
err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Entity") != nil {
return "", "", err
}
}
@@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) {
// Sign the subkey in the entity
for _, subkey := range entity.Subkeys {
err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Sign Subkey") != nil {
return "", "", err
}
}
@@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) {
// EncryptPGPString takes data and a public key and encrypts using that
// public key
func EncryptPGPString(data string, pbKey string) (string, error) {
+
pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey)
- if checkError(err, "Decoding Base64 Public Key") != nil {
+ if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil {
return "", err
}
dataBytes := []byte(data)
pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes)))
- if checkError(err, "Reading entity from PGP key") != nil {
+ if smslogger.CheckError(err, "Reading entity from PGP key") != nil {
return "", err
}
// encrypt string
buf := new(bytes.Buffer)
out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil)
- if checkError(err, "Creating Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Creating Encryption Pipe") != nil {
return "", err
}
_, err = out.Write(dataBytes)
- if checkError(err, "Writing to Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil {
return "", err
}
err = out.Close()
- if checkError(err, "Closing Encryption Pipe") != nil {
+ if smslogger.CheckError(err, "Closing Encryption Pipe") != nil {
return "", err
}
@@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) {
// DecryptPGPString decrypts a PGP encoded input string and returns
// a base64 representation of the decoded string
func DecryptPGPString(data string, prKey string) (string, error) {
+
// Convert private key to bytes from base64
prKeyBytes, err := base64.StdEncoding.DecodeString(prKey)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 private key: " + err.Error())
+ if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil {
return "", err
}
dataBytes, err := base64.StdEncoding.DecodeString(data)
- if err != nil {
- smslogger.WriteError("Error Decoding base64 data: " + err.Error())
+ if smslogger.CheckError(err, "Decoding base64 data") != nil {
return "", err
}
prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes)))
- if err != nil {
- smslogger.WriteError("Error reading entity from PGP key: " + err.Error())
+ if smslogger.CheckError(err, "Read Entity") != nil {
return "", err
}
prEntityList := &openpgp.EntityList{prEntity}
message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil)
- if err != nil {
- smslogger.WriteError("Error Decrypting message: " + err.Error())
+ if smslogger.CheckError(err, "Decrypting Message") != nil {
return "", err
}
@@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) {
func ReadFromFile(fileName string) (string, error) {
data, err := ioutil.ReadFile(fileName)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot read file: " + fileName)
+ if smslogger.CheckError(err, "Read from file") != nil {
return "", err
}
return string(data), nil
-
}
// WriteToFile writes a PGP key into a file.
@@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) {
func WriteToFile(data string, fileName string) error {
err := ioutil.WriteFile(fileName, []byte(data), 0600)
- if err != nil {
- smslogger.WriteError(err.Error())
- smslogger.WriteError("Cannot write to file: " + fileName)
+ if smslogger.CheckError(err, "Write to file") != nil {
return err
}
return nil
-
}
diff --git a/sms-service/src/sms/backend/backend.go b/sms-service/src/sms/backend/backend.go
index c137636..d7662ef 100644
--- a/sms-service/src/sms/backend/backend.go
+++ b/sms-service/src/sms/backend/backend.go
@@ -60,8 +60,7 @@ func InitSecretBackend() (SecretBackend, error) {
}
err := backendImpl.Init()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitSecretBackend") != nil {
return nil, err
}
diff --git a/sms-service/src/sms/backend/vault.go b/sms-service/src/sms/backend/vault.go
index e26baff..7fee097 100644
--- a/sms-service/src/sms/backend/vault.go
+++ b/sms-service/src/sms/backend/vault.go
@@ -56,9 +56,8 @@ func (v *Vault) initVaultClient() error {
vaultCFG := vaultapi.DefaultConfig()
vaultCFG.Address = v.vaultAddress
client, err := vaultapi.NewClient(vaultCFG)
- if err != nil {
- smslogger.WriteError(err.Error())
- return errors.New("Unable to create new vault client")
+ if smslogger.CheckError(err, "Create new vault client") != nil {
+ return err
}
v.initRoleDone = false
@@ -69,7 +68,6 @@ func (v *Vault) initVaultClient() error {
v.internalDomainMounted = false
v.prkey = ""
return nil
-
}
// Init will initialize the vault connection
@@ -84,8 +82,7 @@ func (v *Vault) Init() error {
v.initializeVault()
err := v.initRole()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "InitRole First Attempt") != nil {
smslogger.WriteInfo("InitRole will try again later")
}
@@ -94,10 +91,10 @@ func (v *Vault) Init() error {
// GetStatus returns the current seal status of vault
func (v *Vault) GetStatus() (bool, error) {
+
sys := v.vaultClient.Sys()
sealStatus, err := sys.SealStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Getting Status") != nil {
return false, errors.New("Error getting status")
}
@@ -112,7 +109,7 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
defer v.Unlock()
if v.shards == nil {
- smslogger.WriteError("Invalid operation")
+ smslogger.WriteError("Invalid operation in RegisterQuorum")
return "", errors.New("Invalid operation")
}
// Pop the slice
@@ -133,10 +130,10 @@ func (v *Vault) RegisterQuorum(pgpkey string) (string, error) {
// Unseal is a passthrough API that allows any
// unseal or initialization processes for the backend
func (v *Vault) Unseal(shard string) error {
+
sys := v.vaultClient.Sys()
_, err := sys.Unseal(shard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Unseal Operation") != nil {
return errors.New("Unable to execute unseal operation with specified shard")
}
@@ -147,17 +144,16 @@ func (v *Vault) Unseal(shard string) error {
// The secret itself is referenced via its name which translates to
// a mount path in vault
func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Tocken Check") != nil {
return Secret{}, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().Read(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return Secret{}, errors.New("Unable to read Secret at provided path")
}
@@ -173,17 +169,16 @@ func (v *Vault) GetSecret(dom string, name string) (Secret, error) {
// ListSecret returns a list of secret names on a particular domain
// The values of the secret are not returned
func (v *Vault) ListSecret(dom string) ([]string, error) {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return nil, errors.New("Token check failed")
}
dom = v.vaultMountPrefix + "/" + dom
sec, err := v.vaultClient.Logical().List(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Read Secret") != nil {
return nil, errors.New("Unable to read Secret at provided path")
}
@@ -209,6 +204,7 @@ func (v *Vault) ListSecret(dom string) ([]string, error) {
// Mounts the internal Domain if its not already mounted
func (v *Vault) mountInternalDomain(name string) error {
+
if v.internalDomainMounted {
return nil
}
@@ -224,14 +220,13 @@ func (v *Vault) mountInternalDomain(name string) error {
}
err := v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
+ if smslogger.CheckError(err, "Mount internal Domain") != nil {
if strings.Contains(err.Error(), "existing mount") {
// It is already mounted
v.internalDomainMounted = true
return nil
}
// Ran into some other error mounting it.
- smslogger.WriteError(err.Error())
return errors.New("Unable to mount internal Domain")
}
@@ -242,16 +237,15 @@ func (v *Vault) mountInternalDomain(name string) error {
// Stores the UUID created for secretdomain in vault
// under v.vaultMountPrefix / smsinternal domain
func (v *Vault) storeUUID(uuid string, name string) error {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check failed")
}
err = v.mountInternalDomain(v.internalDomain)
- if err != nil {
- smslogger.WriteError("Could not mount internal domain")
+ if smslogger.CheckError(err, "Mount Internal Domain") != nil {
return err
}
@@ -263,8 +257,7 @@ func (v *Vault) storeUUID(uuid string, name string) error {
}
err = v.CreateSecret(v.internalDomain, secret)
- if err != nil {
- smslogger.WriteError("Unable to write UUID to internal domain")
+ if smslogger.CheckError(err, "Write UUID to domain") != nil {
return err
}
@@ -273,10 +266,10 @@ func (v *Vault) storeUUID(uuid string, name string) error {
// CreateSecretDomain mounts the kv backend on a path with the given name
func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
+
// Check if token is still valid
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return SecretDomain{}, errors.New("Token Check failed")
}
@@ -291,14 +284,13 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
}
err = v.vaultClient.Sys().Mount(mountPath, mountInput)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Domain") != nil {
return SecretDomain{}, errors.New("Unable to create Secret Domain")
}
uuid, _ := uuid.GenerateUUID()
err = v.storeUUID(uuid, name)
- if err != nil {
+ if smslogger.CheckError(err, "Store UUID") != nil {
// Mount was successful at this point.
// Rollback the mount operation since we could not
// store the UUID for the mount.
@@ -312,9 +304,9 @@ func (v *Vault) CreateSecretDomain(name string) (SecretDomain, error) {
// CreateSecret creates a secret mounted on a particular domain name
// The secret itself is mounted on a path specified by name
func (v *Vault) CreateSecret(dom string, sec Secret) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -323,8 +315,7 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// Vault return is empty on successful write
// TODO: Check if values is not empty
_, err = v.vaultClient.Logical().Write(dom+"/"+sec.Name, sec.Values)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Secret") != nil {
return errors.New("Unable to create Secret at provided path")
}
@@ -334,9 +325,9 @@ func (v *Vault) CreateSecret(dom string, sec Secret) error {
// DeleteSecretDomain deletes a secret domain which translates to
// an unmount operation on the given path in Vault
func (v *Vault) DeleteSecretDomain(name string) error {
+
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token Check Failed")
}
@@ -344,8 +335,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
mountPath := v.vaultMountPrefix + "/" + name
err = v.vaultClient.Sys().Unmount(mountPath)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Domain") != nil {
return errors.New("Unable to delete domain specified")
}
@@ -356,8 +346,7 @@ func (v *Vault) DeleteSecretDomain(name string) error {
func (v *Vault) DeleteSecret(dom string, name string) error {
err := v.checkToken()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Token Check") != nil {
return errors.New("Token check failed")
}
@@ -365,8 +354,7 @@ func (v *Vault) DeleteSecret(dom string, name string) error {
// Vault return is empty on successful delete
_, err = v.vaultClient.Logical().Delete(dom + "/" + name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Delete Secret") != nil {
return errors.New("Unable to delete Secret at provided path")
}
@@ -406,15 +394,13 @@ func (v *Vault) initRole() error {
rules := `path "sms/*" { capabilities = ["create", "read", "update", "delete", "list"] }
path "sys/mounts/sms*" { capabilities = ["update","delete","create"] }`
err := v.vaultClient.Sys().PutPolicy(v.policyName, rules)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Creating Policy") != nil {
return errors.New("Unable to create policy for approle creation")
}
//Check if applrole is mounted
authMounts, err := v.vaultClient.Sys().ListAuth()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Mount Auth Backend") != nil {
return errors.New("Unable to get mounted auth backends")
}
@@ -440,8 +426,7 @@ func (v *Vault) initRole() error {
// Create a role-id
v.vaultClient.Logical().Write("auth/approle/role/"+rName, data)
sec, err := v.vaultClient.Logical().Read("auth/approle/role/" + rName + "/role-id")
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create RoleID") != nil {
return errors.New("Unable to create role ID for approle")
}
v.roleID = sec.Data["role_id"].(string)
@@ -449,8 +434,7 @@ func (v *Vault) initRole() error {
// Create a secret-id to go with it
sec, err = v.vaultClient.Logical().Write("auth/approle/role/"+rName+"/secret-id",
map[string]interface{}{})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create SecretID") != nil {
return errors.New("Unable to create secret ID for role")
}
@@ -462,8 +446,7 @@ func (v *Vault) initRole() error {
* using the unseal shards.
*/
err = v.vaultClient.Auth().Token().RevokeSelf(v.vaultToken)
- if err != nil {
- smslogger.WriteWarn(err.Error())
+ if smslogger.CheckError(err, "Revoke Root Token") != nil {
smslogger.WriteWarn("Unable to Revoke Token")
} else {
// Revoked successfully and clear it
@@ -481,6 +464,7 @@ func (v *Vault) initRole() error {
// Function checkToken() gets called multiple times to create
// temporary tokens
func (v *Vault) checkToken() error {
+
v.Lock()
defer v.Unlock()
@@ -501,8 +485,7 @@ func (v *Vault) checkToken() error {
// Create a temporary token using our roleID and secretID
out, err := v.vaultClient.Logical().Write("auth/approle/login",
map[string]interface{}{"role_id": v.roleID, "secret_id": v.secretID})
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Create Temp Token") != nil {
return errors.New("Unable to create Temporary Token for Role")
}
@@ -516,11 +499,12 @@ func (v *Vault) checkToken() error {
// vaultInit() is used to initialize the vault in cases where it is not
// initialized. This happens once during intial bring up.
func (v *Vault) initializeVault() error {
+
// Check for vault init status and don't exit till it is initialized
for {
init, err := v.vaultClient.Sys().InitStatus()
- if err != nil {
- smslogger.WriteError("Unable to get initStatus, trying again in 10s: " + err.Error())
+ if smslogger.CheckError(err, "Get Vault Init Status") != nil {
+ smslogger.WriteInfo("Trying again in 10s...")
time.Sleep(time.Second * 10)
continue
}
@@ -545,7 +529,7 @@ func (v *Vault) initializeVault() error {
pbkey, prkey, err := smsauth.GeneratePGPKeyPair()
- if err != nil {
+ if smslogger.CheckError(err, "Generating PGP Keys") != nil {
smslogger.WriteError("Error Generating PGP Keys. Vault Init will not use encryption!")
} else {
initReq.PGPKeys = []string{pbkey, pbkey, pbkey}
@@ -553,8 +537,7 @@ func (v *Vault) initializeVault() error {
}
resp, err := v.vaultClient.Sys().Init(initReq)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "Initialize Vault") != nil {
return errors.New("FATAL: Unable to initialize Vault")
}
diff --git a/sms-service/src/sms/backend/vault_test.go b/sms-service/src/sms/backend/vault_test.go
index 484c395..4862665 100644
--- a/sms-service/src/sms/backend/vault_test.go
+++ b/sms-service/src/sms/backend/vault_test.go
@@ -17,9 +17,11 @@
package backend
import (
+ vaultapi "github.com/hashicorp/vault/api"
credAppRole "github.com/hashicorp/vault/builtin/credential/approle"
vaulthttp "github.com/hashicorp/vault/http"
vaultlogical "github.com/hashicorp/vault/logical"
+ vaultinmem "github.com/hashicorp/vault/physical/inmem"
vaulttesting "github.com/hashicorp/vault/vault"
"reflect"
smslog "sms/log"
@@ -229,3 +231,39 @@ func TestDeleteSecret(t *testing.T) {
t.Fatal("DeleteSecret: Error Creating secret")
}
}
+
+func TestInitializeVault(t *testing.T) {
+
+ inm, err := vaultinmem.NewInmem(nil, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ core, err := vaulttesting.NewCore(&vaulttesting.CoreConfig{
+ DisableMlock: true,
+ DisableCache: true,
+ Physical: inm,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ln, addr := vaulthttp.TestServer(t, core)
+ defer ln.Close()
+
+ client, err := vaultapi.NewClient(&vaultapi.Config{
+ Address: addr,
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ v := &Vault{}
+ v.initVaultClient()
+ v.vaultClient = client
+
+ err = v.initializeVault()
+ if err != nil {
+ t.Fatal("InitializeVault: Error initializing Vault")
+ }
+}
diff --git a/sms-service/src/sms/handler/handler.go b/sms-service/src/sms/handler/handler.go
index dbf3f93..7ce9e01 100644
--- a/sms-service/src/sms/handler/handler.go
+++ b/sms-service/src/sms/handler/handler.go
@@ -37,15 +37,13 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
var d smsbackend.SecretDomain
err := json.NewDecoder(r.Body).Decode(&d)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
dom, err := h.secretBackend.CreateSecretDomain(d.Name)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -53,8 +51,7 @@ func (h handler) createSecretDomainHandler(w http.ResponseWriter, r *http.Reques
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusCreated)
err = json.NewEncoder(w).Encode(dom)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -66,8 +63,7 @@ func (h handler) deleteSecretDomainHandler(w http.ResponseWriter, r *http.Reques
domName := vars["domName"]
err := h.secretBackend.DeleteSecretDomain(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretDomainHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -84,15 +80,13 @@ func (h handler) createSecretHandler(w http.ResponseWriter, r *http.Request) {
// Get secrets to be stored from body
var b smsbackend.Secret
err := json.NewDecoder(r.Body).Decode(&b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
err = h.secretBackend.CreateSecret(domName, b)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "CreateSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -107,16 +101,14 @@ func (h handler) getSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
sec, err := h.secretBackend.GetSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(sec)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "GetSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -128,8 +120,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
domName := vars["domName"]
secList, err := h.secretBackend.ListSecret(domName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -143,8 +134,7 @@ func (h handler) listSecretHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(retStruct)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "ListSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -157,8 +147,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
secName := vars["secretName"]
err := h.secretBackend.DeleteSecret(domName, secName)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "DeleteSecretHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -169,8 +158,7 @@ func (h handler) deleteSecretHandler(w http.ResponseWriter, r *http.Request) {
// statusHandler returns information related to SMS and SMS backend services
func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
s, err := h.secretBackend.GetStatus()
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -183,8 +171,7 @@ func (h handler) statusHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(status)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "StatusHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -207,15 +194,13 @@ func (h handler) unsealHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
err = h.secretBackend.Unseal(inp.UnsealShard)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "UnsealHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -235,15 +220,13 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
decoder := json.NewDecoder(r.Body)
decoder.DisallowUnknownFields()
err := decoder.Decode(&inp)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, "Bad input JSON", http.StatusBadRequest)
return
}
sh, err := h.secretBackend.RegisterQuorum(inp.PGPKey)
- if err != nil {
- smslogger.WriteError(err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
@@ -257,8 +240,7 @@ func (h handler) registerHandler(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(shStruct)
- if err != nil {
- smslogger.WriteError("Unable to encode response: " + err.Error())
+ if smslogger.CheckError(err, "RegisterHandler") != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
diff --git a/sms-service/src/sms/handler/handler_test.go b/sms-service/src/sms/handler/handler_test.go
index 52637f3..c1e55ed 100644
--- a/sms-service/src/sms/handler/handler_test.go
+++ b/sms-service/src/sms/handler/handler_test.go
@@ -48,7 +48,7 @@ func (b *TestBackend) Unseal(shard string) error {
}
func (b *TestBackend) RegisterQuorum(pgpkey string) (string, error) {
- return "", nil
+ return "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=", nil
}
func (b *TestBackend) GetSecret(dom string, sec string) (smsbackend.Secret, error) {
@@ -127,8 +127,49 @@ func TestStatusHandler(t *testing.T) {
}
}
+func TestRegisterHandler(t *testing.T) {
+ body := `{
+ "pgpkey":"asdasdasdasdgkjgljoiwera",
+ "quorumid":"123e4567-e89b-12d3-a456-426655440000"
+ }`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/register", reader)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ rr := httptest.NewRecorder()
+ hr := http.HandlerFunc(h.registerHandler)
+
+ hr.ServeHTTP(rr, req)
+
+ ret := rr.Code
+ if ret != http.StatusOK {
+ t.Errorf("registerHandler returned wrong status code: %v vs %v",
+ ret, http.StatusOK)
+ }
+
+ expected := struct {
+ Shard string `json:"shard"`
+ }{
+ "N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU=",
+ }
+ got := struct {
+ Shard string `json:"shard"`
+ }{}
+
+ json.NewDecoder(rr.Body).Decode(&got)
+
+ if reflect.DeepEqual(expected, got) == false {
+ t.Errorf("statusHandler returned unexpected body: got %v vs %v",
+ rr.Body.String(), expected)
+ }
+}
+
func TestUnsealHandler(t *testing.T) {
- req, err := http.NewRequest("GET", "/v1/sms/quorum/unseal", nil)
+ body := `{"unsealshard":"N8z4eD2Zgv0eDJrgkkUq3Lh5n2p6Y1Zsui1NIHePlLU="}`
+ reader := strings.NewReader(body)
+ req, err := http.NewRequest("POST", "/v1/sms/quorum/unseal", reader)
if err != nil {
t.Fatal(err)
}
diff --git a/sms-service/src/sms/log/logger.go b/sms-service/src/sms/log/logger.go
index 25da593..660f1ce 100644
--- a/sms-service/src/sms/log/logger.go
+++ b/sms-service/src/sms/log/logger.go
@@ -17,57 +17,82 @@
package log
import (
+ "fmt"
"log"
"os"
)
-var errLogger *log.Logger
-var warnLogger *log.Logger
-var infoLogger *log.Logger
+var errL, warnL, infoL *log.Logger
+var stdErr, stdWarn, stdInfo *log.Logger
// Init will be called by sms.go before any other packages use it
func Init(filePath string) {
+
+ stdErr = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ stdWarn = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ stdInfo = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+
if filePath == "" {
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
+ // We will just to std streams
return
}
f, err := os.Create(filePath)
if err != nil {
- log.Println("Unable to create a log file")
- log.Println(err)
- errLogger = log.New(os.Stderr, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(os.Stdout, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(os.Stdout, "INFO: ", log.Lshortfile|log.LstdFlags)
- } else {
- errLogger = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
- warnLogger = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
- infoLogger = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
+ stdErr.Println("Unable to create log file: " + err.Error())
+ return
}
+
+ errL = log.New(f, "ERROR: ", log.Lshortfile|log.LstdFlags)
+ warnL = log.New(f, "WARNING: ", log.Lshortfile|log.LstdFlags)
+ infoL = log.New(f, "INFO: ", log.Lshortfile|log.LstdFlags)
}
// WriteError writes output to the writer we have
-// defined durint its creation with ERROR prefix
+// defined during its creation with ERROR prefix
func WriteError(msg string) {
- if errLogger != nil {
- errLogger.Println(msg)
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
}
}
// WriteWarn writes output to the writer we have
-// defined durint its creation with WARNING prefix
+// defined during its creation with WARNING prefix
func WriteWarn(msg string) {
- if warnLogger != nil {
- warnLogger.Println(msg)
+ if warnL != nil {
+ warnL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdWarn != nil {
+ stdWarn.Output(2, fmt.Sprintln(msg))
}
}
// WriteInfo writes output to the writer we have
-// defined durint its creation with INFO prefix
+// defined during its creation with INFO prefix
func WriteInfo(msg string) {
- if infoLogger != nil {
- infoLogger.Println(msg)
+ if infoL != nil {
+ infoL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdInfo != nil {
+ stdInfo.Output(2, fmt.Sprintln(msg))
+ }
+}
+
+//CheckError is a helper function to reduce
+//repitition of error checkign blocks of code
+func CheckError(err error, topic string) error {
+ if err != nil {
+ msg := topic + ": " + err.Error()
+ if errL != nil {
+ errL.Output(2, fmt.Sprintln(msg))
+ }
+ if stdErr != nil {
+ stdErr.Output(2, fmt.Sprintln(msg))
+ }
+ return err
}
+ return nil
}