diff options
Diffstat (limited to 'sms-service/src/sms/auth/auth.go')
-rw-r--r-- | sms-service/src/sms/auth/auth.go | 82 |
1 files changed, 30 insertions, 52 deletions
diff --git a/sms-service/src/sms/auth/auth.go b/sms-service/src/sms/auth/auth.go index 7172505..038e31d 100644 --- a/sms-service/src/sms/auth/auth.go +++ b/sms-service/src/sms/auth/auth.go @@ -29,39 +29,27 @@ import ( smslogger "sms/log" ) -var tlsConfig *tls.Config - -func checkError(err error, topic string) error { - if err != nil { - smslogger.WriteError(topic + ": " + err.Error()) - return err - } - - return nil -} - // GetTLSConfig initializes a tlsConfig using the CA's certificate // This config is then used to enable the server for mutual TLS func GetTLSConfig(caCertFile string) (*tls.Config, error) { + // Initialize tlsConfig once - if tlsConfig == nil { - caCert, err := ioutil.ReadFile(caCertFile) + caCert, err := ioutil.ReadFile(caCertFile) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCert) + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) - tlsConfig = &tls.Config{ - // Change to RequireAndVerify once we have mandatory certs - ClientAuth: tls.VerifyClientCertIfGiven, - ClientCAs: caCertPool, - MinVersion: tls.VersionTLS12, - } - tlsConfig.BuildNameToCertificate() + tlsConfig := &tls.Config{ + // Change to RequireAndVerify once we have mandatory certs + ClientAuth: tls.VerifyClientCertIfGiven, + ClientCAs: caCertPool, + MinVersion: tls.VersionTLS12, } + tlsConfig.BuildNameToCertificate() return tlsConfig, nil } @@ -70,22 +58,21 @@ func GetTLSConfig(caCertFile string) (*tls.Config, error) { // A base64 encoded form of the public part of the entity // A base64 encoded form of the private key func GeneratePGPKeyPair() (string, string, error) { + var entity *openpgp.Entity config := &packet.Config{ DefaultHash: crypto.SHA256, } entity, err := openpgp.NewEntity("aaf.sms.init", "PGP Key for unsealing", "", config) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Create Entity") != nil { return "", "", err } // Sign the identity in the entity for _, id := range entity.Identities { err = id.SelfSignature.SignUserId(id.UserId.Id, entity.PrimaryKey, entity.PrivateKey, nil) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Entity") != nil { return "", "", err } } @@ -93,8 +80,7 @@ func GeneratePGPKeyPair() (string, string, error) { // Sign the subkey in the entity for _, subkey := range entity.Subkeys { err := subkey.Sig.SignKey(subkey.PublicKey, entity.PrivateKey, nil) - if err != nil { - smslogger.WriteError(err.Error()) + if smslogger.CheckError(err, "Sign Subkey") != nil { return "", "", err } } @@ -113,32 +99,33 @@ func GeneratePGPKeyPair() (string, string, error) { // EncryptPGPString takes data and a public key and encrypts using that // public key func EncryptPGPString(data string, pbKey string) (string, error) { + pbKeyBytes, err := base64.StdEncoding.DecodeString(pbKey) - if checkError(err, "Decoding Base64 Public Key") != nil { + if smslogger.CheckError(err, "Decoding Base64 Public Key") != nil { return "", err } dataBytes := []byte(data) pbEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(pbKeyBytes))) - if checkError(err, "Reading entity from PGP key") != nil { + if smslogger.CheckError(err, "Reading entity from PGP key") != nil { return "", err } // encrypt string buf := new(bytes.Buffer) out, err := openpgp.Encrypt(buf, []*openpgp.Entity{pbEntity}, nil, nil, nil) - if checkError(err, "Creating Encryption Pipe") != nil { + if smslogger.CheckError(err, "Creating Encryption Pipe") != nil { return "", err } _, err = out.Write(dataBytes) - if checkError(err, "Writing to Encryption Pipe") != nil { + if smslogger.CheckError(err, "Writing to Encryption Pipe") != nil { return "", err } err = out.Close() - if checkError(err, "Closing Encryption Pipe") != nil { + if smslogger.CheckError(err, "Closing Encryption Pipe") != nil { return "", err } @@ -149,29 +136,26 @@ func EncryptPGPString(data string, pbKey string) (string, error) { // DecryptPGPString decrypts a PGP encoded input string and returns // a base64 representation of the decoded string func DecryptPGPString(data string, prKey string) (string, error) { + // Convert private key to bytes from base64 prKeyBytes, err := base64.StdEncoding.DecodeString(prKey) - if err != nil { - smslogger.WriteError("Error Decoding base64 private key: " + err.Error()) + if smslogger.CheckError(err, "Decoding Base64 Private Key") != nil { return "", err } dataBytes, err := base64.StdEncoding.DecodeString(data) - if err != nil { - smslogger.WriteError("Error Decoding base64 data: " + err.Error()) + if smslogger.CheckError(err, "Decoding base64 data") != nil { return "", err } prEntity, err := openpgp.ReadEntity(packet.NewReader(bytes.NewBuffer(prKeyBytes))) - if err != nil { - smslogger.WriteError("Error reading entity from PGP key: " + err.Error()) + if smslogger.CheckError(err, "Read Entity") != nil { return "", err } prEntityList := &openpgp.EntityList{prEntity} message, err := openpgp.ReadMessage(bytes.NewBuffer(dataBytes), prEntityList, nil, nil) - if err != nil { - smslogger.WriteError("Error Decrypting message: " + err.Error()) + if smslogger.CheckError(err, "Decrypting Message") != nil { return "", err } @@ -186,13 +170,10 @@ func DecryptPGPString(data string, prKey string) (string, error) { func ReadFromFile(fileName string) (string, error) { data, err := ioutil.ReadFile(fileName) - if err != nil { - smslogger.WriteError(err.Error()) - smslogger.WriteError("Cannot read file: " + fileName) + if smslogger.CheckError(err, "Read from file") != nil { return "", err } return string(data), nil - } // WriteToFile writes a PGP key into a file. @@ -200,11 +181,8 @@ func ReadFromFile(fileName string) (string, error) { func WriteToFile(data string, fileName string) error { err := ioutil.WriteFile(fileName, []byte(data), 0600) - if err != nil { - smslogger.WriteError(err.Error()) - smslogger.WriteError("Cannot write to file: " + fileName) + if smslogger.CheckError(err, "Write to file") != nil { return err } return nil - } |