diff options
Diffstat (limited to 'server/resty')
102 files changed, 19979 insertions, 0 deletions
diff --git a/server/resty/evp.lua b/server/resty/evp.lua new file mode 100644 index 0000000..584ff5a --- /dev/null +++ b/server/resty/evp.lua @@ -0,0 +1,804 @@ +local ffi = require "ffi" +local ffi_copy = ffi.copy +local ffi_gc = ffi.gc +local ffi_new = ffi.new +local ffi_string = ffi.string +local ffi_cast = ffi.cast +local _C = ffi.C + +local _M = { _VERSION = "0.2.3" } + +local ngx = ngx + + +local CONST = { + SHA256_DIGEST = "SHA256", + SHA512_DIGEST = "SHA512", + -- ref : https://github.com/openssl/openssl/blob/master/include/openssl/rsa.h + RSA_PKCS1_PADDING = 1, + RSA_SSLV23_PADDING = 2, + RSA_NO_PADDING = 3, + RSA_PKCS1_OAEP_PADDING = 4, + RSA_X931_PADDING = 5, + RSA_PKCS1_PSS_PADDING = 6, + -- ref : https://github.com/openssl/openssl/blob/master/include/openssl/evp.h + NID_rsaEncryption = 6, + EVP_PKEY_RSA = 6, + EVP_PKEY_ALG_CTRL = 0x1000, + EVP_PKEY_CTRL_RSA_PADDING = 0x1000 + 1, + + EVP_PKEY_OP_TYPE_CRYPT = 768, + EVP_PKEY_CTRL_RSA_OAEP_MD = 0x1000 + 9 +} +_M.CONST = CONST + + +-- Reference: https://wiki.openssl.org/index.php/EVP_Signing_and_Verifying +ffi.cdef[[ +// Error handling +unsigned long ERR_get_error(void); +const char * ERR_reason_error_string(unsigned long e); + +// Basic IO +typedef struct bio_st BIO; +typedef struct bio_method_st BIO_METHOD; +BIO_METHOD *BIO_s_mem(void); +BIO * BIO_new(BIO_METHOD *type); +int BIO_puts(BIO *bp,const char *buf); +void BIO_vfree(BIO *a); +int BIO_write(BIO *b, const void *buf, int len); + +// RSA +typedef struct rsa_st RSA; +int RSA_size(const RSA *rsa); +void RSA_free(RSA *rsa); +typedef int pem_password_cb(char *buf, int size, int rwflag, void *userdata); +RSA * PEM_read_bio_RSAPrivateKey(BIO *bp, RSA **rsa, pem_password_cb *cb, + void *u); +RSA * PEM_read_bio_RSAPublicKey(BIO *bp, RSA **rsa, pem_password_cb *cb, + void *u); + +// EC_KEY +typedef struct ec_key_st EC_KEY; +void EC_KEY_free(EC_KEY *key); +EC_KEY * PEM_read_bio_ECPrivateKey(BIO *bp, EC_KEY **key, pem_password_cb *cb, + void *u); +EC_KEY * PEM_read_bio_ECPublicKey(BIO *bp, EC_KEY **key, pem_password_cb *cb, + void *u); +// EVP PKEY +typedef struct evp_pkey_st EVP_PKEY; +typedef struct engine_st ENGINE; +EVP_PKEY *EVP_PKEY_new(void); +int EVP_PKEY_set1_RSA(EVP_PKEY *pkey,RSA *key); +int EVP_PKEY_set1_EC_KEY(EVP_PKEY *pkey,EC_KEY *key); +EVP_PKEY *EVP_PKEY_new_mac_key(int type, ENGINE *e, + const unsigned char *key, int keylen); +void EVP_PKEY_free(EVP_PKEY *key); +int i2d_RSA(RSA *a, unsigned char **out); + +// Additional typedef of ECC operations (DER/RAW sig conversion) +typedef struct bignum_st BIGNUM; +BIGNUM *BN_new(void); +void BN_free(BIGNUM *a); +int BN_num_bits(const BIGNUM *a); +int BN_bn2bin(const BIGNUM *a, unsigned char *to); +BIGNUM *BN_bin2bn(const unsigned char *s, int len, BIGNUM *ret); +char *BN_bn2hex(const BIGNUM *a); + + +typedef struct ECDSA_SIG_st { + BIGNUM *r; + BIGNUM *s;} ECDSA_SIG; +ECDSA_SIG* ECDSA_SIG_new(void); +int i2d_ECDSA_SIG(const ECDSA_SIG *sig, unsigned char **pp); +ECDSA_SIG* d2i_ECDSA_SIG(ECDSA_SIG **sig, unsigned char **pp, +long len); +void ECDSA_SIG_free(ECDSA_SIG *sig); + +typedef struct ecgroup_st EC_GROUP; + +EC_GROUP *EC_KEY_get0_group(const EC_KEY *key); +EC_KEY *EVP_PKEY_get0_EC_KEY(EVP_PKEY *pkey); +int EC_GROUP_get_order(const EC_GROUP *group, BIGNUM *order, void *ctx); + + +// PUBKEY +EVP_PKEY *PEM_read_bio_PUBKEY(BIO *bp, EVP_PKEY **x, + pem_password_cb *cb, void *u); + +// X509 +typedef struct x509_st X509; +X509 *PEM_read_bio_X509(BIO *bp, X509 **x, pem_password_cb *cb, void *u); +EVP_PKEY * X509_get_pubkey(X509 *x); +void X509_free(X509 *a); +void EVP_PKEY_free(EVP_PKEY *key); +int i2d_X509(X509 *a, unsigned char **out); +X509 *d2i_X509_bio(BIO *bp, X509 **x); + +// X509 store +typedef struct x509_store_st X509_STORE; +typedef struct X509_crl_st X509_CRL; +X509_STORE *X509_STORE_new(void ); +int X509_STORE_add_cert(X509_STORE *ctx, X509 *x); + // Use this if we want to load the certs directly from a variables +int X509_STORE_add_crl(X509_STORE *ctx, X509_CRL *x); +int X509_STORE_load_locations (X509_STORE *ctx, + const char *file, const char *dir); +void X509_STORE_free(X509_STORE *v); + +// X509 store context +typedef struct x509_store_ctx_st X509_STORE_CTX; +X509_STORE_CTX *X509_STORE_CTX_new(void); +int X509_STORE_CTX_init(X509_STORE_CTX *ctx, X509_STORE *store, + X509 *x509, void *chain); +int X509_verify_cert(X509_STORE_CTX *ctx); +void X509_STORE_CTX_cleanup(X509_STORE_CTX *ctx); +int X509_STORE_CTX_get_error(X509_STORE_CTX *ctx); +const char *X509_verify_cert_error_string(long n); +void X509_STORE_CTX_free(X509_STORE_CTX *ctx); + +// EVP Sign/Verify +typedef struct env_md_ctx_st EVP_MD_CTX; +typedef struct env_md_st EVP_MD; +typedef struct evp_pkey_ctx_st EVP_PKEY_CTX; +const EVP_MD *EVP_get_digestbyname(const char *name); + +//OpenSSL 1.0 +EVP_MD_CTX *EVP_MD_CTX_create(void); +void EVP_MD_CTX_destroy(EVP_MD_CTX *ctx); + +//OpenSSL 1.1 +EVP_MD_CTX *EVP_MD_CTX_new(void); +void EVP_MD_CTX_free(EVP_MD_CTX *ctx); + +int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, ENGINE *impl); +int EVP_DigestSignInit(EVP_MD_CTX *ctx, EVP_PKEY_CTX **pctx, + const EVP_MD *type, ENGINE *e, EVP_PKEY *pkey); +int EVP_DigestUpdate(EVP_MD_CTX *ctx,const void *d, + size_t cnt); +int EVP_DigestSignFinal(EVP_MD_CTX *ctx, + unsigned char *sigret, size_t *siglen); + +int EVP_DigestVerifyInit(EVP_MD_CTX *ctx, EVP_PKEY_CTX **pctx, + const EVP_MD *type, ENGINE *e, EVP_PKEY *pkey); +int EVP_DigestVerifyFinal(EVP_MD_CTX *ctx, + unsigned char *sig, size_t siglen); + +// Fingerprints +int X509_digest(const X509 *data,const EVP_MD *type, + unsigned char *md, unsigned int *len); + +//EVP encrypt decrypt +EVP_PKEY_CTX *EVP_PKEY_CTX_new(EVP_PKEY *pkey, ENGINE *e); +void EVP_PKEY_CTX_free(EVP_PKEY_CTX *ctx); + +int EVP_PKEY_CTX_ctrl(EVP_PKEY_CTX *ctx, int keytype, int optype, + int cmd, int p1, void *p2); + +int EVP_PKEY_size(EVP_PKEY *pkey); + +int EVP_PKEY_encrypt_init(EVP_PKEY_CTX *ctx); +int EVP_PKEY_encrypt(EVP_PKEY_CTX *ctx, + unsigned char *out, size_t *outlen, + const unsigned char *in, size_t inlen); + +int EVP_PKEY_decrypt_init(EVP_PKEY_CTX *ctx); +int EVP_PKEY_decrypt(EVP_PKEY_CTX *ctx, + unsigned char *out, size_t *outlen, + const unsigned char *in, size_t inlen); + + +]] + + +local function _err(ret) + -- The openssl error queue can have multiple items, print them all separated by ': ' + local errs = {} + local code = _C.ERR_get_error() + while code ~= 0 do + table.insert(errs, 1, ffi_string(_C.ERR_reason_error_string(code))) + code = _C.ERR_get_error() + end + + if #errs == 0 then + return ret, "Zero error code (null arguments?)" + end + return ret, table.concat(errs, ": ") +end + +local ctx_new, ctx_free +local openssl11, e = pcall(function () + local ctx = _C.EVP_MD_CTX_new() + _C.EVP_MD_CTX_free(ctx) +end) + +ngx.log(ngx.DEBUG, "openssl11=", openssl11, " err=", e) + +if openssl11 then + ctx_new = function () + return _C.EVP_MD_CTX_new() + end + ctx_free = function (ctx) + ffi_gc(ctx, _C.EVP_MD_CTX_free) + end +else + ctx_new = function () + local ctx = _C.EVP_MD_CTX_create() + return ctx + end + ctx_free = function (ctx) + ffi_gc(ctx, _C.EVP_MD_CTX_destroy) + end +end + +local function _new_key(self, opts) + local bio = _C.BIO_new(_C.BIO_s_mem()) + ffi_gc(bio, _C.BIO_vfree) + if _C.BIO_puts(bio, opts.pem_private_key) < 0 then + return _err() + end + + local pass + if opts.password then + local plen = #opts.password + pass = ffi_new("unsigned char[?]", plen + 1) + ffi_copy(pass, opts.password, plen) + end + + local key = nil + if self.algo == "RSA" then + key = _C.PEM_read_bio_RSAPrivateKey(bio, nil, nil, pass) + ffi_gc(key, _C.RSA_free) + elseif self.algo == "ECDSA" then + key = _C.PEM_read_bio_ECPrivateKey(bio, nil, nil, pass) + ffi_gc(key, _C.EC_KEY_free) + end + + if not key then + return _err() + end + + local evp_pkey = _C.EVP_PKEY_new() + if evp_pkey == nil then + return _err() + end + + ffi_gc(evp_pkey, _C.EVP_PKEY_free) + if self.algo == "RSA" then + if _C.EVP_PKEY_set1_RSA(evp_pkey, key) ~= 1 then + return _err() + end + elseif self.algo == "ECDSA" then + if _C.EVP_PKEY_set1_EC_KEY(evp_pkey, key) ~= 1 then + return _err() + end + end + + self.evp_pkey = evp_pkey + return self, nil +end + +local function _create_evp_ctx(self, encrypt) + self.ctx = _C.EVP_PKEY_CTX_new(self.evp_pkey, nil) + if self.ctx == nil then + return _err() + end + + ffi_gc(self.ctx, _C.EVP_PKEY_CTX_free) + + local md = _C.EVP_get_digestbyname(self.digest_alg) + if ffi_cast("void *", md) == nil then + return nil, "Unknown message digest" + end + + if encrypt then + if _C.EVP_PKEY_encrypt_init(self.ctx) <= 0 then + return _err() + end + else + if _C.EVP_PKEY_decrypt_init(self.ctx) <= 0 then + return _err() + end + end + + if _C.EVP_PKEY_CTX_ctrl(self.ctx, CONST.EVP_PKEY_RSA, -1, CONST.EVP_PKEY_CTRL_RSA_PADDING, + self.padding, nil) <= 0 then + return _err() + end + + if self.padding == CONST.RSA_PKCS1_OAEP_PADDING then + if _C.EVP_PKEY_CTX_ctrl(self.ctx, CONST.EVP_PKEY_RSA, CONST.EVP_PKEY_OP_TYPE_CRYPT, + CONST.EVP_PKEY_CTRL_RSA_OAEP_MD, 0, ffi_cast("void *", md)) <= 0 then + return _err() + end + end + + return self.ctx +end + +local RSASigner = {algo="RSA"} +_M.RSASigner = RSASigner + +--- Create a new RSASigner +-- @param pem_private_key A private key string in PEM format +-- @param password password for the private key (if required) +-- @returns RSASigner, err_string +function RSASigner.new(self, pem_private_key, password) + return _new_key ( + self, + { + pem_private_key = pem_private_key, + password = password + } + ) +end + + +--- Sign a message +-- @param message The message to sign +-- @param digest_name The digest format to use (e.g., "SHA256") +-- @returns signature, error_string +function RSASigner.sign(self, message, digest_name) + local buf = ffi_new("unsigned char[?]", 1024) + local len = ffi_new("size_t[1]", 1024) + + local ctx = ctx_new() + if ctx == nil then + return _err() + end + ctx_free(ctx) + + local md = _C.EVP_get_digestbyname(digest_name) + if md == nil then + return _err() + end + + if _C.EVP_DigestInit_ex(ctx, md, nil) ~= 1 then + return _err() + end + + local ret = _C.EVP_DigestSignInit(ctx, nil, md, nil, self.evp_pkey) + if ret ~= 1 then + return _err() + end + if _C.EVP_DigestUpdate(ctx, message, #message) ~= 1 then + return _err() + end + if _C.EVP_DigestSignFinal(ctx, buf, len) ~= 1 then + return _err() + end + return ffi_string(buf, len[0]), nil +end + + +local ECSigner = {algo="ECDSA"} +_M.ECSigner = ECSigner + +--- Create a new ECSigner +-- @param pem_private_key A private key string in PEM format +-- @param password password for the private key (if required) +-- @returns ECSigner, err_string +function ECSigner.new(self, pem_private_key, password) + return RSASigner.new(self, pem_private_key, password) +end + +--- Sign a message with ECDSA +-- @param message The message to sign +-- @param digest_name The digest format to use (e.g., "SHA256") +-- @returns signature, error_string +function ECSigner.sign(self, message, digest_name) + return RSASigner.sign(self, message, digest_name) +end + +--- Converts a ASN.1 DER signature to RAW r,s +-- @param signature The ASN.1 DER signature +-- @returns signature, error_string +function ECSigner.get_raw_sig(self, signature) + if not signature then + return nil, "Must pass a signature to convert" + end + local sig_ptr = ffi_new("unsigned char *[1]") + local sig_bin = ffi_new("unsigned char [?]", #signature) + ffi_copy(sig_bin, signature, #signature) + + sig_ptr[0] = sig_bin + local sig = _C.d2i_ECDSA_SIG(nil, sig_ptr, #signature) + ffi_gc(sig, _C.ECDSA_SIG_free) + + local rbytes = math.floor((_C.BN_num_bits(sig.r)+7)/8) + local sbytes = math.floor((_C.BN_num_bits(sig.s)+7)/8) + + -- Ensure we copy the BN in a padded form + local ec = _C.EVP_PKEY_get0_EC_KEY(self.evp_pkey) + local ecgroup = _C.EC_KEY_get0_group(ec) + + local order = _C.BN_new() + ffi_gc(order, _C.BN_free) + + -- res is an int, if 0, curve not found + local res = _C.EC_GROUP_get_order(ecgroup, order, nil) + + -- BN_num_bytes is a #define, so have to use BN_num_bits + local order_size_bytes = math.floor((_C.BN_num_bits(order)+7)/8) + local resbuf_len = order_size_bytes *2 + local resbuf = ffi_new("unsigned char[?]", resbuf_len) + + -- Let's whilst preserving MSB + _C.BN_bn2bin(sig.r, resbuf + order_size_bytes - rbytes) + _C.BN_bn2bin(sig.s, resbuf + (order_size_bytes*2) - sbytes) + + local raw = ffi_string(resbuf, resbuf_len) + return raw, nil +end + +local RSAVerifier = {} +_M.RSAVerifier = RSAVerifier + + +--- Create a new RSAVerifier +-- @param key_source An instance of Cert or PublicKey used for verification +-- @returns RSAVerifier, error_string +function RSAVerifier.new(self, key_source) + if not key_source then + return nil, "You must pass in an key_source for a public key" + end + local evp_public_key = key_source.public_key + self.evp_pkey = evp_public_key + return self, nil +end + +--- Verify a message is properly signed +-- @param message The original message +-- @param the signature to verify +-- @param digest_name The digest type that was used to sign +-- @returns bool, error_string +function RSAVerifier.verify(self, message, sig, digest_name) + local md = _C.EVP_get_digestbyname(digest_name) + if md == nil then + return _err(false) + end + + local ctx = ctx_new() + if ctx == nil then + return _err(false) + end + ctx_free(ctx) + + if _C.EVP_DigestInit_ex(ctx, md, nil) ~= 1 then + return _err(false) + end + + local ret = _C.EVP_DigestVerifyInit(ctx, nil, md, nil, self.evp_pkey) + if ret ~= 1 then + return _err(false) + end + if _C.EVP_DigestUpdate(ctx, message, #message) ~= 1 then + return _err(false) + end + local sig_bin = ffi_new("unsigned char[?]", #sig) + ffi_copy(sig_bin, sig, #sig) + if _C.EVP_DigestVerifyFinal(ctx, sig_bin, #sig) == 1 then + return true, nil + else + return false, "Verification failed" + end +end + +local ECVerifier = {} +_M.ECVerifier = ECVerifier +--- Create a new ECVerifier +-- @param key_source An instance of Cert or PublicKey used for verification +-- @returns ECVerifier, error_string +function ECVerifier.new(self, key_source) + return RSAVerifier.new(self, key_source) +end + +--- Verify a message is properly signed +-- @param message The original message +-- @param the signature to verify +-- @param digest_name The digest type that was used to sign +-- @returns bool, error_string +function ECVerifier.verify(self, message, sig, digest_name) + -- We have to convert the signature back from RAW to ASN1 for verification + local der_sig, err = self:get_der_sig(sig) + if not der_sig then + return nil, err + end + return RSAVerifier.verify(self, message, der_sig, digest_name) +end + +--- Converts a RAW r,s signature to ASN.1 DER signature (ECDSA) +-- @param signature The raw signature +-- @returns signature, error_string +function ECVerifier.get_der_sig(self, signature) + if not signature then + return nil, "Must pass a signature to convert" + end + -- inspired from https://bit.ly/2yZxzxJ + local ec = _C.EVP_PKEY_get0_EC_KEY(self.evp_pkey) + local ecgroup = _C.EC_KEY_get0_group(ec) + + local order = _C.BN_new() + ffi_gc(order, _C.BN_free) + + -- res is an int, if 0, curve not found + local res = _C.EC_GROUP_get_order(ecgroup, order, nil) + + -- BN_num_bytes is a #define, so have to use BN_num_bits + local order_size_bytes = math.floor((_C.BN_num_bits(order)+7)/8) + + if #signature ~= 2 * order_size_bytes then + return nil, "signature length != 2 * order length" + end + + local sig_bytes = ffi_new("unsigned char [?]", #signature) + ffi_copy(sig_bytes, signature, #signature) + local ecdsa = _C.ECDSA_SIG_new() + ffi_gc(ecdsa, _C.ECDSA_SIG_free) + + -- Those do not need to be GCed as they are cleared by the ECDSA_SIG_free() + local r = _C.BN_bin2bn(sig_bytes, order_size_bytes, nil) + local s = _C.BN_bin2bn(sig_bytes + order_size_bytes, order_size_bytes, nil) + + ecdsa.r = r + ecdsa.s = s + + -- Gives us the buffer size to allocate + local der_len = _C.i2d_ECDSA_SIG(ecdsa, nil) + + local der_sig_ptr = ffi_new("unsigned char *[1]") + local der_sig_bin = ffi_new("unsigned char [?]", der_len) + der_sig_ptr[0] = der_sig_bin + der_len = _C.i2d_ECDSA_SIG(ecdsa, der_sig_ptr) + + local der_str = ffi_string(der_sig_bin, der_len) + return der_str, nil +end + + +local Cert = {} +_M.Cert = Cert + + +--- Create a new Certificate object +-- @param payload A PEM or DER format X509 certificate +-- @returns Cert, error_string +function Cert.new(self, payload) + if not payload then + return nil, "Must pass a PEM or binary DER cert" + end + local bio = _C.BIO_new(_C.BIO_s_mem()) + ffi_gc(bio, _C.BIO_vfree) + local x509 + if payload:find('-----BEGIN') then + if _C.BIO_puts(bio, payload) < 0 then + return _err() + end + x509 = _C.PEM_read_bio_X509(bio, nil, nil, nil) + else + if _C.BIO_write(bio, payload, #payload) < 0 then + return _err() + end + x509 = _C.d2i_X509_bio(bio, nil) + end + if x509 == nil then + return _err() + end + ffi_gc(x509, _C.X509_free) + self.x509 = x509 + local public_key, err = self:get_public_key() + if not public_key then + return nil, err + end + + ffi_gc(public_key, _C.EVP_PKEY_free) + + self.public_key = public_key + return self, nil +end + + +--- Retrieve the DER format of the certificate +-- @returns Binary DER format, error_string +function Cert.get_der(self) + local bufp = ffi_new("unsigned char *[1]") + local len = _C.i2d_X509(self.x509, bufp) + if len < 0 then + return _err() + end + local der = ffi_string(bufp[0], len) + return der, nil +end + +--- Retrieve the cert fingerprint +-- @param digest_name the Type of digest to use (e.g., "SHA256") +-- @returns fingerprint_string, error_string +function Cert.get_fingerprint(self, digest_name) + local md = _C.EVP_get_digestbyname(digest_name) + if md == nil then + return _err() + end + local buf = ffi_new("unsigned char[?]", 32) + local len = ffi_new("unsigned int[1]", 32) + if _C.X509_digest(self.x509, md, buf, len) ~= 1 then + return _err() + end + local raw = ffi_string(buf, len[0]) + local t = {} + raw:gsub('.', function (c) table.insert(t, string.format('%02X', string.byte(c))) end) + return table.concat(t, ":"), nil +end + +--- Retrieve the public key from the CERT +-- @returns An OpenSSL EVP PKEY object representing the public key, error_string +function Cert.get_public_key(self) + local evp_pkey = _C.X509_get_pubkey(self.x509) + if evp_pkey == nil then + return _err() + end + + return evp_pkey, nil +end + +--- Verify the Certificate is trusted +-- @param trusted_cert_file File path to a list of PEM encoded trusted certificates +-- @return bool, error_string +function Cert.verify_trust(self, trusted_cert_file) + local store = _C.X509_STORE_new() + if store == nil then + return _err(false) + end + ffi_gc(store, _C.X509_STORE_free) + if _C.X509_STORE_load_locations(store, trusted_cert_file, nil) ~=1 then + return _err(false) + end + + local ctx = _C.X509_STORE_CTX_new() + if store == nil then + return _err(false) + end + ffi_gc(ctx, _C.X509_STORE_CTX_free) + if _C.X509_STORE_CTX_init(ctx, store, self.x509, nil) ~= 1 then + return _err(false) + end + + if _C.X509_verify_cert(ctx) ~= 1 then + local code = _C.X509_STORE_CTX_get_error(ctx) + local msg = ffi_string(_C.X509_verify_cert_error_string(code)) + _C.X509_STORE_CTX_cleanup(ctx) + return false, msg + end + _C.X509_STORE_CTX_cleanup(ctx) + return true, nil + +end + +local PublicKey = {} +_M.PublicKey = PublicKey + +--- Create a new PublicKey object +-- +-- If a PEM fornatted key is provided, the key must start with +-- +-- ----- BEGIN PUBLIC KEY ----- +-- +-- @param payload A PEM or DER format public key file +-- @return PublicKey, error_string +function PublicKey.new(self, payload) + if not payload then + return nil, "Must pass a PEM or binary DER public key" + end + local bio = _C.BIO_new(_C.BIO_s_mem()) + ffi_gc(bio, _C.BIO_vfree) + local pkey + if payload:find('-----BEGIN') then + if _C.BIO_puts(bio, payload) < 0 then + return _err() + end + pkey = _C.PEM_read_bio_PUBKEY(bio, nil, nil, nil) + else + if _C.BIO_write(bio, payload, #payload) < 0 then + return _err() + end + pkey = _C.d2i_PUBKEY_bio(bio, nil) + end + if pkey == nil then + return _err() + end + ffi_gc(pkey, _C.EVP_PKEY_free) + self.public_key = pkey + return self, nil +end + +local RSAEncryptor= {} +_M.RSAEncryptor = RSAEncryptor + +--- Create a new RSAEncryptor +-- @param key_source An instance of Cert or PublicKey used for verification +-- @param padding padding type to use +-- @param digest_alg digest algorithm to use +-- @returns RSAEncryptor, err_string +function RSAEncryptor.new(self, key_source, padding, digest_alg) + if not key_source then + return nil, "You must pass in an key_source for a public key" + end + local evp_public_key = key_source.public_key + self.evp_pkey = evp_public_key + self.padding = padding or CONST.RSA_PKCS1_OAEP_PADDING + self.digest_alg = digest_alg or CONST.SHA256_DIGEST + return self, nil +end + + + +--- Encrypts the payload +-- @param payload plain text payload +-- @returns encrypted payload, error_string +function RSAEncryptor.encrypt(self, payload) + + local ctx, err_str = _create_evp_ctx(self, true) + + if not ctx then + return nil, err_str + end + local len = ffi_new("size_t [1]") + if _C.EVP_PKEY_encrypt(ctx, nil, len, payload, #payload) <= 0 then + return _err() + end + local buf = ffi_new("unsigned char[?]", len[0]) + if _C.EVP_PKEY_encrypt(ctx, buf, len, payload, #payload) <= 0 then + return _err() + end + + return ffi_string(buf, len[0]) + +end + + +local RSADecryptor= {algo="RSA"} +_M.RSADecryptor = RSADecryptor + +--- Create a new RSADecryptor +-- @param pem_private_key A private key string in PEM format +-- @param password password for the private key (if required) +-- @param padding padding type to use +-- @param digest_alg digest algorithm to use +-- @returns RSADecryptor, error_string +function RSADecryptor.new(self, pem_private_key, password, padding, digest_alg) + self.padding = padding or CONST.RSA_PKCS1_OAEP_PADDING + self.digest_alg = digest_alg or CONST.SHA256_DIGEST + return _new_key ( + self, + { + pem_private_key = pem_private_key, + password = password + } + ) +end + +--- Decrypts the cypher text +-- @param cypher_text encrypted payload +-- @param padding rsa pading mode to use, Defaults to RSA_PKCS1_PADDING +function RSADecryptor.decrypt(self, cypher_text) + + local ctx, err_code, err_str = _create_evp_ctx(self, false) + + if not ctx then + return nil, err_code, err_str + end + + local len = ffi_new("size_t [1]") + if _C.EVP_PKEY_decrypt(ctx, nil, len, cypher_text, #cypher_text) <= 0 then + return _err() + end + + local buf = ffi_new("unsigned char[?]", len[0]) + if _C.EVP_PKEY_decrypt(ctx, buf, len, cypher_text, #cypher_text) <= 0 then + return _err() + end + + return ffi_string(buf, len[0]) + +end + +return _M diff --git a/server/resty/hmac.lua b/server/resty/hmac.lua new file mode 100644 index 0000000..8d94a8b --- /dev/null +++ b/server/resty/hmac.lua @@ -0,0 +1,167 @@ + +local str_util = require "resty.string" +local to_hex = str_util.to_hex +local ffi = require "ffi" +local ffi_new = ffi.new +local ffi_str = ffi.string +local ffi_gc = ffi.gc +local ffi_typeof = ffi.typeof +local C = ffi.C +local setmetatable = setmetatable +local error = error + + +local _M = { _VERSION = '0.04' } + +local mt = { __index = _M } + + +ffi.cdef[[ +typedef struct engine_st ENGINE; +typedef struct evp_pkey_ctx_st EVP_PKEY_CTX; +typedef struct evp_md_ctx_st EVP_MD_CTX; +typedef struct evp_md_st EVP_MD; +typedef struct hmac_ctx_st HMAC_CTX; + +//OpenSSL 1.0 +void HMAC_CTX_init(HMAC_CTX *ctx); +void HMAC_CTX_cleanup(HMAC_CTX *ctx); + +//OpenSSL 1.1 +HMAC_CTX *HMAC_CTX_new(void); +void HMAC_CTX_free(HMAC_CTX *ctx); + +int HMAC_Init_ex(HMAC_CTX *ctx, const void *key, int len, const EVP_MD *md, ENGINE *impl); +int HMAC_Update(HMAC_CTX *ctx, const unsigned char *data, size_t len); +int HMAC_Final(HMAC_CTX *ctx, unsigned char *md, unsigned int *len); + +const EVP_MD *EVP_md5(void); +const EVP_MD *EVP_sha1(void); +const EVP_MD *EVP_sha256(void); +const EVP_MD *EVP_sha512(void); +]] + +local buf = ffi_new("unsigned char[64]") +local res_len = ffi_new("unsigned int[1]") +local hashes = { + MD5 = C.EVP_md5(), + SHA1 = C.EVP_sha1(), + SHA256 = C.EVP_sha256(), + SHA512 = C.EVP_sha512() +} + +local ctx_new, ctx_free +local openssl11, e = pcall(function () + local ctx = C.HMAC_CTX_new() + C.HMAC_CTX_free(ctx) +end) +if openssl11 then + ctx_new = function () + return C.HMAC_CTX_new() + end + ctx_free = function (ctx) + C.HMAC_CTX_free(ctx) + end +else + ffi.cdef [[ + struct evp_md_ctx_st + { + const EVP_MD *digest; + ENGINE *engine; + unsigned long flags; + void *md_data; + EVP_PKEY_CTX *pctx; + int (*update)(EVP_MD_CTX *ctx,const void *data,size_t count); + }; + + struct evp_md_st + { + int type; + int pkey_type; + int md_size; + unsigned long flags; + int (*init)(EVP_MD_CTX *ctx); + int (*update)(EVP_MD_CTX *ctx,const void *data,size_t count); + int (*final)(EVP_MD_CTX *ctx,unsigned char *md); + int (*copy)(EVP_MD_CTX *to,const EVP_MD_CTX *from); + int (*cleanup)(EVP_MD_CTX *ctx); + + int (*sign)(int type, const unsigned char *m, unsigned int m_length, unsigned char *sigret, unsigned int *siglen, void *key); + int (*verify)(int type, const unsigned char *m, unsigned int m_length, const unsigned char *sigbuf, unsigned int siglen, void *key); + int required_pkey_type[5]; + int block_size; + int ctx_size; + int (*md_ctrl)(EVP_MD_CTX *ctx, int cmd, int p1, void *p2); + }; + + struct hmac_ctx_st + { + const EVP_MD *md; + EVP_MD_CTX md_ctx; + EVP_MD_CTX i_ctx; + EVP_MD_CTX o_ctx; + unsigned int key_length; + unsigned char key[128]; + }; + ]] + + local ctx_ptr_type = ffi_typeof("HMAC_CTX[1]") + + ctx_new = function () + local ctx = ffi_new(ctx_ptr_type) + C.HMAC_CTX_init(ctx) + return ctx + end + ctx_free = function (ctx) + C.HMAC_CTX_cleanup(ctx) + end +end + + +_M.ALGOS = hashes + + +function _M.new(self, key, hash_algo) + local ctx = ctx_new() + + local _hash_algo = hash_algo or hashes.MD5 + + if C.HMAC_Init_ex(ctx, key, #key, _hash_algo, nil) == 0 then + return nil + end + + ffi_gc(ctx, ctx_free) + + return setmetatable({ _ctx = ctx }, mt) +end + + +function _M.update(self, s) + return C.HMAC_Update(self._ctx, s, #s) == 1 +end + + +function _M.final(self, s, hex_output) + + if s ~= nil then + if C.HMAC_Update(self._ctx, s, #s) == 0 then + return nil + end + end + + if C.HMAC_Final(self._ctx, buf, res_len) == 1 then + if hex_output == true then + return to_hex(ffi_str(buf, res_len[0])) + end + return ffi_str(buf, res_len[0]) + end + + return nil +end + + +function _M.reset(self) + return C.HMAC_Init_ex(self._ctx, nil, 0, nil, nil) == 1 +end + +return _M diff --git a/server/resty/http.lua b/server/resty/http.lua new file mode 100644 index 0000000..70c3bee --- /dev/null +++ b/server/resty/http.lua @@ -0,0 +1,1178 @@ +local http_headers = require "resty.http_headers" + +local ngx = ngx +local ngx_socket_tcp = ngx.socket.tcp +local ngx_req = ngx.req +local ngx_req_socket = ngx_req.socket +local ngx_req_get_headers = ngx_req.get_headers +local ngx_req_get_method = ngx_req.get_method +local str_lower = string.lower +local str_upper = string.upper +local str_find = string.find +local str_sub = string.sub +local tbl_concat = table.concat +local tbl_insert = table.insert +local ngx_encode_args = ngx.encode_args +local ngx_re_match = ngx.re.match +local ngx_re_gmatch = ngx.re.gmatch +local ngx_re_sub = ngx.re.sub +local ngx_re_gsub = ngx.re.gsub +local ngx_re_find = ngx.re.find +local ngx_log = ngx.log +local ngx_DEBUG = ngx.DEBUG +local ngx_ERR = ngx.ERR +local ngx_var = ngx.var +local ngx_print = ngx.print +local ngx_header = ngx.header +local co_yield = coroutine.yield +local co_create = coroutine.create +local co_status = coroutine.status +local co_resume = coroutine.resume +local setmetatable = setmetatable +local tonumber = tonumber +local tostring = tostring +local unpack = unpack +local rawget = rawget +local select = select +local ipairs = ipairs +local pairs = pairs +local pcall = pcall +local type = type + + +-- http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html#sec13.5.1 +local HOP_BY_HOP_HEADERS = { + ["connection"] = true, + ["keep-alive"] = true, + ["proxy-authenticate"] = true, + ["proxy-authorization"] = true, + ["te"] = true, + ["trailers"] = true, + ["transfer-encoding"] = true, + ["upgrade"] = true, + ["content-length"] = true, -- Not strictly hop-by-hop, but Nginx will deal + -- with this (may send chunked for example). +} + + +local EXPECTING_BODY = { + POST = true, + PUT = true, + PATCH = true, +} + + +-- Reimplemented coroutine.wrap, returning "nil, err" if the coroutine cannot +-- be resumed. This protects user code from infinite loops when doing things like +-- repeat +-- local chunk, err = res.body_reader() +-- if chunk then -- <-- This could be a string msg in the core wrap function. +-- ... +-- end +-- until not chunk +local co_wrap = function(func) + local co = co_create(func) + if not co then + return nil, "could not create coroutine" + else + return function(...) + if co_status(co) == "suspended" then + return select(2, co_resume(co, ...)) + else + return nil, "can't resume a " .. co_status(co) .. " coroutine" + end + end + end +end + + +-- Returns a new table, recursively copied from the one given. +-- +-- @param table table to be copied +-- @return table +local function tbl_copy(orig) + local orig_type = type(orig) + local copy + if orig_type == "table" then + copy = {} + for orig_key, orig_value in next, orig, nil do + copy[tbl_copy(orig_key)] = tbl_copy(orig_value) + end + else -- number, string, boolean, etc + copy = orig + end + return copy +end + + +local _M = { + _VERSION = '0.17.0-beta.1', +} +_M._USER_AGENT = "lua-resty-http/" .. _M._VERSION .. " (Lua) ngx_lua/" .. ngx.config.ngx_lua_version + +local mt = { __index = _M } + + +local HTTP = { + [1.0] = " HTTP/1.0\r\n", + [1.1] = " HTTP/1.1\r\n", +} + + +local DEFAULT_PARAMS = { + method = "GET", + path = "/", + version = 1.1, +} + + +local DEBUG = false + + +function _M.new(_) + local sock, err = ngx_socket_tcp() + if not sock then + return nil, err + end + return setmetatable({ sock = sock, keepalive = true }, mt) +end + + +function _M.debug(d) + DEBUG = (d == true) +end + + +function _M.set_timeout(self, timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeout(timeout) +end + + +function _M.set_timeouts(self, connect_timeout, send_timeout, read_timeout) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:settimeouts(connect_timeout, send_timeout, read_timeout) +end + +do + local aio_connect = require "resty.http_connect" + -- Function signatures to support: + -- ok, err, ssl_session = httpc:connect(options_table) + -- ok, err = httpc:connect(host, port, options_table?) + -- ok, err = httpc:connect("unix:/path/to/unix.sock", options_table?) + function _M.connect(self, options, ...) + if type(options) == "table" then + -- all-in-one interface + return aio_connect(self, options) + else + -- backward compatible + return self:tcp_only_connect(options, ...) + end + end +end + +function _M.tcp_only_connect(self, ...) + ngx_log(ngx_DEBUG, "Use of deprecated `connect` method signature") + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.host = select(1, ...) + self.port = select(2, ...) + + -- If port is not a number, this is likely a unix domain socket connection. + if type(self.port) ~= "number" then + self.port = nil + end + + self.keepalive = true + self.ssl = false + + return sock:connect(...) +end + + +function _M.set_keepalive(self, ...) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + if self.keepalive == true then + return sock:setkeepalive(...) + else + -- The server said we must close the connection, so we cannot setkeepalive. + -- If close() succeeds we return 2 instead of 1, to differentiate between + -- a normal setkeepalive() failure and an intentional close(). + local res, err = sock:close() + if res then + return 2, "connection must be closed" + else + return res, err + end + end +end + + +function _M.get_reused_times(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:getreusedtimes() +end + + +function _M.close(self) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + return sock:close() +end + + +local function _should_receive_body(method, code) + if method == "HEAD" then return nil end + if code == 204 or code == 304 then return nil end + if code >= 100 and code < 200 then return nil end + return true +end + + +function _M.parse_uri(_, uri, query_in_path) + if query_in_path == nil then query_in_path = true end + + local m, err = ngx_re_match( + uri, + [[^(?:(http[s]?):)?//((?:[^\[\]:/\?]+)|(?:\[.+\]))(?::(\d+))?([^\?]*)\??(.*)]], + "jo" + ) + + if not m then + if err then + return nil, "failed to match the uri: " .. uri .. ", " .. err + end + + return nil, "bad uri: " .. uri + else + -- If the URI is schemaless (i.e. //example.com) try to use our current + -- request scheme. + if not m[1] then + -- Schema-less URIs can occur in client side code, implying "inherit + -- the schema from the current request". We support it for a fairly + -- specific case; if for example you are using the ESI parser in + -- ledge (https://github.com/ledgetech/ledge) to perform in-flight + -- sub requests on the edge based on instructions found in markup, + -- those URIs may also be schemaless with the intention that the + -- subrequest would inherit the schema just like JavaScript would. + local scheme = ngx_var.scheme + if scheme == "http" or scheme == "https" then + m[1] = scheme + else + return nil, "schemaless URIs require a request context: " .. uri + end + end + + if m[3] then + m[3] = tonumber(m[3]) + else + if m[1] == "https" then + m[3] = 443 + else + m[3] = 80 + end + end + if not m[4] or "" == m[4] then m[4] = "/" end + + if query_in_path and m[5] and m[5] ~= "" then + m[4] = m[4] .. "?" .. m[5] + m[5] = nil + end + + return m, nil + end +end + + +local function _format_request(self, params) + local version = params.version + local headers = params.headers or {} + + local query = params.query or "" + if type(query) == "table" then + query = "?" .. ngx_encode_args(query) + elseif query ~= "" and str_sub(query, 1, 1) ~= "?" then + query = "?" .. query + end + + -- Initialize request + local req = { + str_upper(params.method), + " ", + self.path_prefix or "", + params.path, + query, + HTTP[version], + -- Pre-allocate slots for minimum headers and carriage return. + true, + true, + true, + } + local c = 7 -- req table index it's faster to do this inline vs table.insert + + -- Append headers + for key, values in pairs(headers) do + key = tostring(key) + + if type(values) == "table" then + for _, value in pairs(values) do + req[c] = key .. ": " .. tostring(value) .. "\r\n" + c = c + 1 + end + + else + req[c] = key .. ": " .. tostring(values) .. "\r\n" + c = c + 1 + end + end + + -- Close headers + req[c] = "\r\n" + + return tbl_concat(req) +end + + +local function _receive_status(sock) + local line, err = sock:receive("*l") + if not line then + return nil, nil, nil, err + end + + local version = tonumber(str_sub(line, 6, 8)) + if not version then + return nil, nil, nil, + "couldn't parse HTTP version from response status line: " .. line + end + + local status = tonumber(str_sub(line, 10, 12)) + if not status then + return nil, nil, nil, + "couldn't parse status code from response status line: " .. line + end + + local reason = str_sub(line, 14) + + return status, version, reason +end + + +local function _receive_headers(sock) + local headers = http_headers.new() + + repeat + local line, err = sock:receive("*l") + if not line then + return nil, err + end + + local m, err = ngx_re_match(line, "([^:\\s]+):\\s*(.*)", "jo") + if err then ngx_log(ngx_ERR, err) end + + if not m then + break + end + + local key = m[1] + local val = m[2] + if headers[key] then + if type(headers[key]) ~= "table" then + headers[key] = { headers[key] } + end + tbl_insert(headers[key], tostring(val)) + else + headers[key] = tostring(val) + end + until ngx_re_find(line, "^\\s*$", "jo") + + return headers, nil +end + + +local function transfer_encoding_is_chunked(headers) + local te = headers["Transfer-Encoding"] + if not te then + return false + end + + -- Handle duplicate headers + -- This shouldn't happen but can in the real world + if type(te) ~= "string" then + te = tbl_concat(te, ",") + end + + return str_find(str_lower(te), "chunked", 1, true) ~= nil +end +_M.transfer_encoding_is_chunked = transfer_encoding_is_chunked + + +local function _chunked_body_reader(sock, default_chunk_size) + return co_wrap(function(max_chunk_size) + local remaining = 0 + local length + max_chunk_size = max_chunk_size or default_chunk_size + + repeat + -- If we still have data on this chunk + if max_chunk_size and remaining > 0 then + + if remaining > max_chunk_size then + -- Consume up to max_chunk_size + length = max_chunk_size + remaining = remaining - max_chunk_size + else + -- Consume all remaining + length = remaining + remaining = 0 + end + else -- This is a fresh chunk + + -- Receive the chunk size + local str, err = sock:receive("*l") + if not str then + co_yield(nil, err) + end + + length = tonumber(str, 16) + + if not length then + co_yield(nil, "unable to read chunksize") + end + + if max_chunk_size and length > max_chunk_size then + -- Consume up to max_chunk_size + remaining = length - max_chunk_size + length = max_chunk_size + end + end + + if length > 0 then + local str, err = sock:receive(length) + if not str then + co_yield(nil, err) + end + + max_chunk_size = co_yield(str) or default_chunk_size + + -- If we're finished with this chunk, read the carriage return. + if remaining == 0 then + sock:receive(2) -- read \r\n + end + else + -- Read the last (zero length) chunk's carriage return + sock:receive(2) -- read \r\n + end + + until length == 0 + end) +end + + +local function _body_reader(sock, content_length, default_chunk_size) + return co_wrap(function(max_chunk_size) + max_chunk_size = max_chunk_size or default_chunk_size + + if not content_length and max_chunk_size then + -- We have no length, but wish to stream. + -- HTTP 1.0 with no length will close connection, so read chunks to the end. + repeat + local str, err, partial = sock:receive(max_chunk_size) + if not str and err == "closed" then + co_yield(partial, err) + end + + max_chunk_size = tonumber(co_yield(str) or default_chunk_size) + if max_chunk_size and max_chunk_size < 0 then max_chunk_size = nil end + + if not max_chunk_size then + ngx_log(ngx_ERR, "Buffer size not specified, bailing") + break + end + until not str + + elseif not content_length then + -- We have no length but don't wish to stream. + -- HTTP 1.0 with no length will close connection, so read to the end. + co_yield(sock:receive("*a")) + + elseif not max_chunk_size then + -- We have a length and potentially keep-alive, but want everything. + co_yield(sock:receive(content_length)) + + else + -- We have a length and potentially a keep-alive, and wish to stream + -- the response. + local received = 0 + repeat + local length = max_chunk_size + if received + length > content_length then + length = content_length - received + end + + if length > 0 then + local str, err = sock:receive(length) + if not str then + co_yield(nil, err) + end + received = received + length + + max_chunk_size = tonumber(co_yield(str) or default_chunk_size) + if max_chunk_size and max_chunk_size < 0 then max_chunk_size = nil end + + if not max_chunk_size then + ngx_log(ngx_ERR, "Buffer size not specified, bailing") + break + end + end + + until length == 0 + end + end) +end + + +local function _no_body_reader() + return nil +end + + +local function _read_body(res) + local reader = res.body_reader + + if not reader then + -- Most likely HEAD or 304 etc. + return nil, "no body to be read" + end + + local chunks = {} + local c = 1 + + local chunk, err + repeat + chunk, err = reader() + + if err then + return nil, err, tbl_concat(chunks) -- Return any data so far. + end + if chunk then + chunks[c] = chunk + c = c + 1 + end + until not chunk + + return tbl_concat(chunks) +end + + +local function _trailer_reader(sock) + return co_wrap(function() + co_yield(_receive_headers(sock)) + end) +end + + +local function _read_trailers(res) + local reader = res.trailer_reader + if not reader then + return nil, "no trailers" + end + + local trailers = reader() + setmetatable(res.headers, { __index = trailers }) +end + + +local function _send_body(sock, body) + if type(body) == "function" then + repeat + local chunk, err, partial = body() + + if chunk then + local ok, err = sock:send(chunk) + + if not ok then + return nil, err + end + elseif err ~= nil then + return nil, err, partial + end + + until chunk == nil + elseif body ~= nil then + local bytes, err = sock:send(body) + + if not bytes then + return nil, err + end + end + return true, nil +end + + +local function _handle_continue(sock, body) + local status, version, reason, err = _receive_status(sock) --luacheck: no unused + if not status then + return nil, nil, err + end + + -- Only send body if we receive a 100 Continue + if status == 100 then + local ok, err = sock:receive("*l") -- Read carriage return + if not ok then + return nil, nil, err + end + _send_body(sock, body) + end + return status, version, err +end + + +function _M.send_request(self, params) + -- Apply defaults + setmetatable(params, { __index = DEFAULT_PARAMS }) + + local sock = self.sock + local body = params.body + local headers = http_headers.new() + + -- We assign one-by-one so that the metatable can handle case insensitivity + -- for us. You can blame the spec for this inefficiency. + local params_headers = params.headers or {} + for k, v in pairs(params_headers) do + headers[k] = v + end + + if not headers["Proxy-Authorization"] then + -- TODO: next major, change this to always override the provided + -- header. Can't do that yet because it would be breaking. + -- The connect method uses self.http_proxy_auth in the poolname so + -- that should be leading. + headers["Proxy-Authorization"] = self.http_proxy_auth + end + + -- Ensure we have appropriate message length or encoding. + do + local is_chunked = transfer_encoding_is_chunked(headers) + + if is_chunked then + -- If we have both Transfer-Encoding and Content-Length we MUST + -- drop the Content-Length, to help prevent request smuggling. + -- https://tools.ietf.org/html/rfc7230#section-3.3.3 + headers["Content-Length"] = nil + + elseif not headers["Content-Length"] then + -- A length was not given, try to calculate one. + + local body_type = type(body) + + if body_type == "function" then + return nil, "Request body is a function but a length or chunked encoding is not specified" + + elseif body_type == "table" then + local length = 0 + for _, v in ipairs(body) do + length = length + #tostring(v) + end + headers["Content-Length"] = length + + elseif body == nil and EXPECTING_BODY[str_upper(params.method)] then + headers["Content-Length"] = 0 + + elseif body ~= nil then + headers["Content-Length"] = #tostring(body) + end + end + end + + if not headers["Host"] then + if (str_sub(self.host, 1, 5) == "unix:") then + return nil, "Unable to generate a useful Host header for a unix domain socket. Please provide one." + end + -- If we have a port (i.e. not connected to a unix domain socket), and this + -- port is non-standard, append it to the Host header. + if self.port then + if self.ssl and self.port ~= 443 then + headers["Host"] = self.host .. ":" .. self.port + elseif not self.ssl and self.port ~= 80 then + headers["Host"] = self.host .. ":" .. self.port + else + headers["Host"] = self.host + end + else + headers["Host"] = self.host + end + end + if not headers["User-Agent"] then + headers["User-Agent"] = _M._USER_AGENT + end + if params.version == 1.0 and not headers["Connection"] then + headers["Connection"] = "Keep-Alive" + end + + params.headers = headers + + -- Format and send request + local req = _format_request(self, params) + if DEBUG then ngx_log(ngx_DEBUG, "\n", req) end + local bytes, err = sock:send(req) + + if not bytes then + return nil, err + end + + -- Send the request body, unless we expect: continue, in which case + -- we handle this as part of reading the response. + if headers["Expect"] ~= "100-continue" then + local ok, err, partial = _send_body(sock, body) + if not ok then + return nil, err, partial + end + end + + return true +end + + +function _M.read_response(self, params) + local sock = self.sock + + local status, version, reason, err + + -- If we expect: continue, we need to handle this, sending the body if allowed. + -- If we don't get 100 back, then status is the actual status. + if params.headers["Expect"] == "100-continue" then + local _status, _version, _err = _handle_continue(sock, params.body) + if not _status then + return nil, _err + elseif _status ~= 100 then + status, version, err = _status, _version, _err -- luacheck: no unused + end + end + + -- Just read the status as normal. + if not status then + status, version, reason, err = _receive_status(sock) + if not status then + return nil, err + end + end + + + local res_headers, err = _receive_headers(sock) + if not res_headers then + return nil, err + end + + -- keepalive is true by default. Determine if this is correct or not. + local ok, connection = pcall(str_lower, res_headers["Connection"]) + if ok then + if (version == 1.1 and str_find(connection, "close", 1, true)) or + (version == 1.0 and not str_find(connection, "keep-alive", 1, true)) then + self.keepalive = false + end + else + -- no connection header + if version == 1.0 then + self.keepalive = false + end + end + + local body_reader = _no_body_reader + local trailer_reader, err + local has_body = false + + -- Receive the body_reader + if _should_receive_body(params.method, status) then + has_body = true + + if version == 1.1 and transfer_encoding_is_chunked(res_headers) then + body_reader, err = _chunked_body_reader(sock) + else + local ok, length = pcall(tonumber, res_headers["Content-Length"]) + if not ok then + -- No content-length header, read until connection is closed by server + length = nil + end + + body_reader, err = _body_reader(sock, length) + end + end + + if res_headers["Trailer"] then + trailer_reader, err = _trailer_reader(sock) + end + + if err then + return nil, err + else + return { + status = status, + reason = reason, + headers = res_headers, + has_body = has_body, + body_reader = body_reader, + read_body = _read_body, + trailer_reader = trailer_reader, + read_trailers = _read_trailers, + } + end +end + + +function _M.request(self, params) + params = tbl_copy(params) -- Take by value + local res, err = self:send_request(params) + if not res then + return res, err + else + return self:read_response(params) + end +end + + +function _M.request_pipeline(self, requests) + requests = tbl_copy(requests) -- Take by value + + for _, params in ipairs(requests) do + if params.headers and params.headers["Expect"] == "100-continue" then + return nil, "Cannot pipeline request specifying Expect: 100-continue" + end + + local res, err = self:send_request(params) + if not res then + return res, err + end + end + + local responses = {} + for i, params in ipairs(requests) do + responses[i] = setmetatable({ + params = params, + response_read = false, + }, { + -- Read each actual response lazily, at the point the user tries + -- to access any of the fields. + __index = function(t, k) + local res, err + if t.response_read == false then + res, err = _M.read_response(self, t.params) + t.response_read = true + + if not res then + ngx_log(ngx_ERR, err) + else + for rk, rv in pairs(res) do + t[rk] = rv + end + end + end + return rawget(t, k) + end, + }) + end + return responses +end + + +function _M.request_uri(self, uri, params) + params = tbl_copy(params or {}) -- Take by value + if self.proxy_opts then + params.proxy_opts = tbl_copy(self.proxy_opts or {}) + end + + do + local parsed_uri, err = self:parse_uri(uri, false) + if not parsed_uri then + return nil, err + end + + local path, query + params.scheme, params.host, params.port, path, query = unpack(parsed_uri) + params.path = params.path or path + params.query = params.query or query + params.ssl_server_name = params.ssl_server_name or params.host + end + + do + local proxy_auth = (params.headers or {})["Proxy-Authorization"] + if proxy_auth and params.proxy_opts then + params.proxy_opts.https_proxy_authorization = proxy_auth + params.proxy_opts.http_proxy_authorization = proxy_auth + end + end + + local ok, err = self:connect(params) + if not ok then + return nil, err + end + + local res, err = self:request(params) + if not res then + self:close() + return nil, err + end + + local body, err = res:read_body() + if not body then + self:close() + return nil, err + end + + res.body = body + + if params.keepalive == false then + local ok, err = self:close() + if not ok then + ngx_log(ngx_ERR, err) + end + + else + local ok, err = self:set_keepalive(params.keepalive_timeout, params.keepalive_pool) + if not ok then + ngx_log(ngx_ERR, err) + end + + end + + return res, nil +end + + +function _M.get_client_body_reader(_, chunksize, sock) + chunksize = chunksize or 65536 + + if not sock then + local ok, err + ok, sock, err = pcall(ngx_req_socket) + + if not ok then + return nil, sock -- pcall err + end + + if not sock then + if err == "no body" then + return nil + else + return nil, err + end + end + end + + local headers = ngx_req_get_headers() + local length = headers.content_length + if length then + return _body_reader(sock, tonumber(length), chunksize) + elseif transfer_encoding_is_chunked(headers) then + -- Not yet supported by ngx_lua but should just work... + return _chunked_body_reader(sock, chunksize) + else + return nil + end +end + + +function _M.set_proxy_options(self, opts) + -- TODO: parse and cache these options, instead of parsing them + -- on each request over and over again (lru-cache on module level) + self.proxy_opts = tbl_copy(opts) -- Take by value +end + + +function _M.get_proxy_uri(self, scheme, host) + if not self.proxy_opts then + return nil + end + + -- Check if the no_proxy option matches this host. Implementation adapted + -- from lua-http library (https://github.com/daurnimator/lua-http) + if self.proxy_opts.no_proxy then + if self.proxy_opts.no_proxy == "*" then + -- all hosts are excluded + return nil + end + + local no_proxy_set = {} + -- wget allows domains in no_proxy list to be prefixed by "." + -- e.g. no_proxy=.mit.edu + for host_suffix in ngx_re_gmatch(self.proxy_opts.no_proxy, "\\.?([^,]+)", "jo") do + no_proxy_set[host_suffix[1]] = true + end + + -- From curl docs: + -- matched as either a domain which contains the hostname, or the + -- hostname itself. For example local.com would match local.com, + -- local.com:80, and www.local.com, but not www.notlocal.com. + -- + -- Therefore, we keep stripping subdomains from the host, compare + -- them to the ones in the no_proxy list and continue until we find + -- a match or until there's only the TLD left + repeat + if no_proxy_set[host] then + return nil + end + + -- Strip the next level from the domain and check if that one + -- is on the list + host = ngx_re_sub(host, "^[^.]+\\.", "", "jo") + until not ngx_re_find(host, "\\.", "jo") + end + + if scheme == "http" and self.proxy_opts.http_proxy then + return self.proxy_opts.http_proxy + end + + if scheme == "https" and self.proxy_opts.https_proxy then + return self.proxy_opts.https_proxy + end + + return nil +end + + +-- ---------------------------------------------------------------------------- +-- The following functions are considered DEPRECATED and may be REMOVED in +-- future releases. Please see the notes in `README.md`. +-- ---------------------------------------------------------------------------- + +function _M.ssl_handshake(self, ...) + ngx_log(ngx_DEBUG, "Use of deprecated function `ssl_handshake`") + + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + self.ssl = true + + return sock:sslhandshake(...) +end + + +function _M.connect_proxy(self, proxy_uri, scheme, host, port, proxy_authorization) + ngx_log(ngx_DEBUG, "Use of deprecated function `connect_proxy`") + + -- Parse the provided proxy URI + local parsed_proxy_uri, err = self:parse_uri(proxy_uri, false) + if not parsed_proxy_uri then + return nil, err + end + + -- Check that the scheme is http (https is not supported for + -- connections between the client and the proxy) + local proxy_scheme = parsed_proxy_uri[1] + if proxy_scheme ~= "http" then + return nil, "protocol " .. proxy_scheme .. " not supported for proxy connections" + end + + -- Make the connection to the given proxy + local proxy_host, proxy_port = parsed_proxy_uri[2], parsed_proxy_uri[3] + local c, err = self:tcp_only_connect(proxy_host, proxy_port) + if not c then + return nil, err + end + + if scheme == "https" then + -- Make a CONNECT request to create a tunnel to the destination through + -- the proxy. The request-target and the Host header must be in the + -- authority-form of RFC 7230 Section 5.3.3. See also RFC 7231 Section + -- 4.3.6 for more details about the CONNECT request + local destination = host .. ":" .. port + local res, err = self:request({ + method = "CONNECT", + path = destination, + headers = { + ["Host"] = destination, + ["Proxy-Authorization"] = proxy_authorization, + } + }) + + if not res then + return nil, err + end + + if res.status < 200 or res.status > 299 then + return nil, "failed to establish a tunnel through a proxy: " .. res.status + end + end + + return c, nil +end + + +function _M.proxy_request(self, chunksize) + ngx_log(ngx_DEBUG, "Use of deprecated function `proxy_request`") + + return self:request({ + method = ngx_req_get_method(), + path = ngx_re_gsub(ngx_var.uri, "\\s", "%20", "jo") .. ngx_var.is_args .. (ngx_var.query_string or ""), + body = self:get_client_body_reader(chunksize), + headers = ngx_req_get_headers(), + }) +end + + +function _M.proxy_response(_, response, chunksize) + ngx_log(ngx_DEBUG, "Use of deprecated function `proxy_response`") + + if not response then + ngx_log(ngx_ERR, "no response provided") + return + end + + ngx.status = response.status + + -- Filter out hop-by-hop headeres + for k, v in pairs(response.headers) do + if not HOP_BY_HOP_HEADERS[str_lower(k)] then + ngx_header[k] = v + end + end + + local reader = response.body_reader + + repeat + local chunk, ok, read_err, print_err + + chunk, read_err = reader(chunksize) + if read_err then + ngx_log(ngx_ERR, read_err) + end + + if chunk then + ok, print_err = ngx_print(chunk) + if not ok then + ngx_log(ngx_ERR, print_err) + end + end + + if read_err or print_err then + break + end + until not chunk +end + + +return _M diff --git a/server/resty/http_connect.lua b/server/resty/http_connect.lua new file mode 100644 index 0000000..18a74b1 --- /dev/null +++ b/server/resty/http_connect.lua @@ -0,0 +1,274 @@ +local ngx_re_gmatch = ngx.re.gmatch +local ngx_re_sub = ngx.re.sub +local ngx_re_find = ngx.re.find +local ngx_log = ngx.log +local ngx_WARN = ngx.WARN + +--[[ +A connection function that incorporates: + - tcp connect + - ssl handshake + - http proxy +Due to this it will be better at setting up a socket pool where connections can +be kept alive. + + +Call it with a single options table as follows: + +client:connect { + scheme = "https" -- scheme to use, or nil for unix domain socket + host = "myhost.com", -- target machine, or a unix domain socket + port = nil, -- port on target machine, will default to 80/443 based on scheme + pool = nil, -- connection pool name, leave blank! this function knows best! + pool_size = nil, -- options as per: https://github.com/openresty/lua-nginx-module#tcpsockconnect + backlog = nil, + + -- ssl options as per: https://github.com/openresty/lua-nginx-module#tcpsocksslhandshake + ssl_reused_session = nil + ssl_server_name = nil, + ssl_send_status_req = nil, + ssl_verify = true, -- NOTE: defaults to true + ctx = nil, -- NOTE: not supported + + -- mTLS options (experimental!) + -- + -- !!! IMPORTANT !!! These options require support for mTLS in cosockets, + -- which is currently only available in the following unmerged PRs. + -- + -- * https://github.com/openresty/lua-nginx-module/pull/1602 + -- * https://github.com/openresty/lua-resty-core/pull/278 + -- + -- The details of this feature may change. You have been warned! + -- + ssl_client_cert = nil, + ssl_client_priv_key = nil, + + proxy_opts, -- proxy opts, defaults to global proxy options +} +]] +local function connect(self, options) + local sock = self.sock + if not sock then + return nil, "not initialized" + end + + local ok, err + + local request_scheme = options.scheme + local request_host = options.host + local request_port = options.port + + local poolname = options.pool + local pool_size = options.pool_size + local backlog = options.backlog + + if request_scheme and not request_port then + request_port = (request_scheme == "https" and 443 or 80) + elseif request_port and not request_scheme then + return nil, "'scheme' is required when providing a port" + end + + -- ssl settings + local ssl, ssl_reused_session, ssl_server_name + local ssl_verify, ssl_send_status_req, ssl_client_cert, ssl_client_priv_key + if request_scheme == "https" then + ssl = true + ssl_reused_session = options.ssl_reused_session + ssl_server_name = options.ssl_server_name + ssl_send_status_req = options.ssl_send_status_req + ssl_verify = true -- default + if options.ssl_verify == false then + ssl_verify = false + end + ssl_client_cert = options.ssl_client_cert + ssl_client_priv_key = options.ssl_client_priv_key + end + + -- proxy related settings + local proxy, proxy_uri, proxy_authorization, proxy_host, proxy_port, path_prefix + proxy = options.proxy_opts or self.proxy_opts + + if proxy then + if request_scheme == "https" then + proxy_uri = proxy.https_proxy + proxy_authorization = proxy.https_proxy_authorization + else + proxy_uri = proxy.http_proxy + proxy_authorization = proxy.http_proxy_authorization + -- When a proxy is used, the target URI must be in absolute-form + -- (RFC 7230, Section 5.3.2.). That is, it must be an absolute URI + -- to the remote resource with the scheme, host and an optional port + -- in place. + -- + -- Since _format_request() constructs the request line by concatenating + -- params.path and params.query together, we need to modify the path + -- to also include the scheme, host and port so that the final form + -- in conformant to RFC 7230. + path_prefix = "http://" .. request_host .. (request_port == 80 and "" or (":" .. request_port)) + end + if not proxy_uri then + proxy = nil + proxy_authorization = nil + path_prefix = nil + end + end + + if proxy and proxy.no_proxy then + -- Check if the no_proxy option matches this host. Implementation adapted + -- from lua-http library (https://github.com/daurnimator/lua-http) + if proxy.no_proxy == "*" then + -- all hosts are excluded + proxy = nil + + else + local host = request_host + local no_proxy_set = {} + -- wget allows domains in no_proxy list to be prefixed by "." + -- e.g. no_proxy=.mit.edu + for host_suffix in ngx_re_gmatch(proxy.no_proxy, "\\.?([^,]+)") do + no_proxy_set[host_suffix[1]] = true + end + + -- From curl docs: + -- matched as either a domain which contains the hostname, or the + -- hostname itself. For example local.com would match local.com, + -- local.com:80, and www.local.com, but not www.notlocal.com. + -- + -- Therefore, we keep stripping subdomains from the host, compare + -- them to the ones in the no_proxy list and continue until we find + -- a match or until there's only the TLD left + repeat + if no_proxy_set[host] then + proxy = nil + proxy_uri = nil + proxy_authorization = nil + break + end + + -- Strip the next level from the domain and check if that one + -- is on the list + host = ngx_re_sub(host, "^[^.]+\\.", "") + until not ngx_re_find(host, "\\.") + end + end + + if proxy then + local proxy_uri_t + proxy_uri_t, err = self:parse_uri(proxy_uri) + if not proxy_uri_t then + return nil, "uri parse error: ", err + end + + local proxy_scheme = proxy_uri_t[1] + if proxy_scheme ~= "http" then + return nil, "protocol " .. tostring(proxy_scheme) .. + " not supported for proxy connections" + end + proxy_host = proxy_uri_t[2] + proxy_port = proxy_uri_t[3] + end + + -- construct a poolname unique within proxy and ssl info + if not poolname then + poolname = (request_scheme or "") + .. ":" .. request_host + .. ":" .. tostring(request_port) + .. ":" .. tostring(ssl) + .. ":" .. (ssl_server_name or "") + .. ":" .. tostring(ssl_verify) + .. ":" .. (proxy_uri or "") + .. ":" .. (request_scheme == "https" and proxy_authorization or "") + -- in the above we only add the 'proxy_authorization' as part of the poolname + -- when the request is https. Because in that case the CONNECT request (which + -- carries the authorization header) is part of the connect procedure, whereas + -- with a plain http request the authorization is part of the actual request. + end + + -- do TCP level connection + local tcp_opts = { pool = poolname, pool_size = pool_size, backlog = backlog } + if proxy then + -- proxy based connection + ok, err = sock:connect(proxy_host, proxy_port, tcp_opts) + if not ok then + return nil, "failed to connect to: " .. (proxy_host or "") .. + ":" .. (proxy_port or "") .. + ": ", err + end + + if ssl and sock:getreusedtimes() == 0 then + -- Make a CONNECT request to create a tunnel to the destination through + -- the proxy. The request-target and the Host header must be in the + -- authority-form of RFC 7230 Section 5.3.3. See also RFC 7231 Section + -- 4.3.6 for more details about the CONNECT request + local destination = request_host .. ":" .. request_port + local res + res, err = self:request({ + method = "CONNECT", + path = destination, + headers = { + ["Host"] = destination, + ["Proxy-Authorization"] = proxy_authorization, + } + }) + + if not res then + return nil, "failed to issue CONNECT to proxy:", err + end + + if res.status < 200 or res.status > 299 then + return nil, "failed to establish a tunnel through a proxy: " .. res.status + end + end + + elseif not request_port then + -- non-proxy, without port -> unix domain socket + ok, err = sock:connect(request_host, tcp_opts) + if not ok then + return nil, err + end + + else + -- non-proxy, regular network tcp + ok, err = sock:connect(request_host, request_port, tcp_opts) + if not ok then + return nil, err + end + end + + local ssl_session + -- Now do the ssl handshake + if ssl and sock:getreusedtimes() == 0 then + + -- Experimental mTLS support + if ssl_client_cert and ssl_client_priv_key then + if type(sock.setclientcert) ~= "function" then + ngx_log(ngx_WARN, "cannot use SSL client cert and key without mTLS support") + + else + -- currently no return value + ok, err = sock:setclientcert(ssl_client_cert, ssl_client_priv_key) + if not ok then + ngx_log(ngx_WARN, "could not set client certificate: ", err) + end + end + end + + ssl_session, err = sock:sslhandshake(ssl_reused_session, ssl_server_name, ssl_verify, ssl_send_status_req) + if not ssl_session then + self:close() + return nil, err + end + end + + self.host = request_host + self.port = request_port + self.keepalive = true + self.ssl = ssl + -- set only for http, https has already been handled + self.http_proxy_auth = request_scheme ~= "https" and proxy_authorization or nil + self.path_prefix = path_prefix + + return true, nil, ssl_session +end + +return connect diff --git a/server/resty/http_headers.lua b/server/resty/http_headers.lua new file mode 100644 index 0000000..97e8157 --- /dev/null +++ b/server/resty/http_headers.lua @@ -0,0 +1,44 @@ +local rawget, rawset, setmetatable = + rawget, rawset, setmetatable + +local str_lower = string.lower + +local _M = { + _VERSION = '0.17.0-beta.1', +} + + +-- Returns an empty headers table with internalised case normalisation. +function _M.new() + local mt = { + normalised = {}, + } + + mt.__index = function(t, k) + return rawget(t, mt.normalised[str_lower(k)]) + end + + mt.__newindex = function(t, k, v) + local k_normalised = str_lower(k) + + -- First time seeing this header field? + if not mt.normalised[k_normalised] then + -- Create a lowercased entry in the metatable proxy, with the value + -- of the given field case + mt.normalised[k_normalised] = k + + -- Set the header using the given field case + rawset(t, k, v) + else + -- We're being updated just with a different field case. Use the + -- normalised metatable proxy to give us the original key case, and + -- perorm a rawset() to update the value. + rawset(t, mt.normalised[k_normalised], v) + end + end + + return setmetatable({}, mt) +end + + +return _M diff --git a/server/resty/jwt-validators.lua b/server/resty/jwt-validators.lua new file mode 100644 index 0000000..df99418 --- /dev/null +++ b/server/resty/jwt-validators.lua @@ -0,0 +1,412 @@ +local _M = { _VERSION = "0.2.3" } + +--[[ + This file defines "validators" to be used in validating a spec. A "validator" is simply a function with + a signature that matches: + + function(val, claim, jwt_json) + + This function returns either true or false. If a validator needs to give more information on why it failed, + then it can also raise an error (which will be used in the "reason" part of the validated jwt_obj). If a + validator returns nil, then it is assumed to have passed (same as returning true) and that you just forgot + to actually return a value. + + There is a special claim name of "__jwt" that can be used to validate the entire jwt_obj. + + "val" is the value being tested. It may be nil if the claim doesn't exist in the jwt_obj. If the function + is being called for the "__jwt" claim, then "val" will contain a deep clone of the full jwt object. + + "claim" is the claim that is being tested. It is passed in just in case a validator needs to do additional + checks. It will be the string "__jwt" if the validator is being called for the entire jwt_object. + + "jwt_json" is a json-encoded representation of the full object that is being tested. It will never be nil, + and can always be decoded using cjson.decode(jwt_json). +]]-- + + +--[[ + A function which will define a validator. It creates both "opt_" and required (non-"opt_") + versions. The function that is passed in is the *optional* version. +]]-- +local function define_validator(name, fx) + _M["opt_" .. name] = fx + _M[name] = function(...) return _M.chain(_M.required(), fx(...)) end +end + +-- Validation messages +local messages = { + nil_validator = "Cannot create validator for nil %s.", + wrong_type_validator = "Cannot create validator for non-%s %s.", + empty_table_validator = "Cannot create validator for empty table %s.", + wrong_table_type_validator = "Cannot create validator for non-%s table %s.", + required_claim = "'%s' claim is required.", + wrong_type_claim = "'%s' is malformed. Expected to be a %s.", + missing_claim = "Missing one of claims - [ %s ]." +} + +-- Local function to make sure that a value is non-nil or raises an error +local function ensure_not_nil(v, e, ...) + return v ~= nil and v or error(string.format(e, ...), 0) +end + +-- Local function to make sure that a value is the given type +local function ensure_is_type(v, t, e, ...) + return type(v) == t and v or error(string.format(e, ...), 0) +end + +-- Local function to make sure that a value is a (non-empty) table +local function ensure_is_table(v, e, ...) + ensure_is_type(v, "table", e, ...) + return ensure_not_nil(next(v), e, ...) +end + +-- Local function to make sure all entries in the table are the given type +local function ensure_is_table_type(v, t, e, ...) + if v ~= nil then + ensure_is_table(v, e, ...) + for _,val in ipairs(v) do + ensure_is_type(val, t, e, ...) + end + end + return v +end + +-- Local function to ensure that a number is non-negative (positive or 0) +local function ensure_is_non_negative(v, e, ...) + if v ~= nil then + ensure_is_type(v, "number", e, ...) + if v >= 0 then + return v + else + error(string.format(e, ...), 0) + end + end +end + +-- A local function which returns simple equality +local function equality_function(val, check) + return val == check +end + +-- A local function which returns string match +local function string_match_function(val, pattern) + return string.match(val, pattern) ~= nil +end + +--[[ + A local function which returns truth on existence of check in vals. + Adopted from auth0/nginx-jwt table_contains by @twistedstream +]]-- +local function table_contains_function(vals, check) + for _, val in pairs(vals) do + if val == check then return true end + end + return false +end + + +-- A local function which returns numeric greater than comparison +local function greater_than_function(val, check) + return val > check +end + +-- A local function which returns numeric greater than or equal comparison +local function greater_than_or_equal_function(val, check) + return val >= check +end + +-- A local function which returns numeric less than comparison +local function less_than_function(val, check) + return val < check +end + +-- A local function which returns numeric less than or equal comparison +local function less_than_or_equal_function(val, check) + return val <= check +end + + +--[[ + Returns a validator that chains the given functions together, one after + another - as long as they keep passing their checks. +]]-- +function _M.chain(...) + local chain_functions = {...} + for _, fx in ipairs(chain_functions) do + ensure_is_type(fx, "function", messages.wrong_type_validator, "function", "chain_function") + end + + return function(val, claim, jwt_json) + for _, fx in ipairs(chain_functions) do + if fx(val, claim, jwt_json) == false then + return false + end + end + return true + end +end + +--[[ + Returns a validator that returns false if a value doesn't exist. If + the value exists and a chain_function is specified, then the value of + chain_function(val, claim, jwt_json) + will be returned, otherwise, true will be returned. This allows for + specifying that a value is both required *and* it must match some + additional check. This function will be used in the "required_*" shortcut + functions for simplification. +]]-- +function _M.required(chain_function) + if chain_function ~= nil then + return _M.chain(_M.required(), chain_function) + end + + return function(val, claim, jwt_json) + ensure_not_nil(val, messages.required_claim, claim) + return true + end +end + +--[[ + Returns a validator which errors with a message if *NONE* of the given claim + keys exist. It is expected that this function is used against a full jwt object. + The claim_keys must be a non-empty table of strings. +]]-- +function _M.require_one_of(claim_keys) + ensure_not_nil(claim_keys, messages.nil_validator, "claim_keys") + ensure_is_type(claim_keys, "table", messages.wrong_type_validator, "table", "claim_keys") + ensure_is_table(claim_keys, messages.empty_table_validator, "claim_keys") + ensure_is_table_type(claim_keys, "string", messages.wrong_table_type_validator, "string", "claim_keys") + + return function(val, claim, jwt_json) + ensure_is_type(val, "table", messages.wrong_type_claim, claim, "table") + ensure_is_type(val.payload, "table", messages.wrong_type_claim, claim .. ".payload", "table") + + for i, v in ipairs(claim_keys) do + if val.payload[v] ~= nil then return true end + end + + error(string.format(messages.missing_claim, table.concat(claim_keys, ", ")), 0) + end +end + +--[[ + Returns a validator that checks if the result of calling the given function for + the tested value and the check value returns true. The value of check_val and + check_function cannot be nil. The optional name is used for error messages and + defaults to "check_value". The optional check_type is used to make sure that + the check type matches and defaults to type(check_val). The first parameter + passed to check_function will *never* be nil (check succeeds if value is nil). + Use the required version to fail on nil. If the check_function raises an + error, that will be appended to the error message. +]]-- +define_validator("check", function(check_val, check_function, name, check_type) + name = name or "check_val" + ensure_not_nil(check_val, messages.nil_validator, name) + + ensure_not_nil(check_function, messages.nil_validator, "check_function") + ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function") + + check_type = check_type or type(check_val) + return function(val, claim, jwt_json) + if val == nil then return true end + + ensure_is_type(val, check_type, messages.wrong_type_claim, claim, check_type) + return check_function(val, check_val) + end +end) + + +--[[ + Returns a validator that checks if a value exactly equals the given check_value. + If the value is nil, then this check succeeds. The value of check_val cannot be + nil. +]]-- +define_validator("equals", function(check_val) + return _M.opt_check(check_val, equality_function, "check_val") +end) + + +--[[ + Returns a validator that checks if a value matches the given pattern. The value + of pattern must be a string. +]]-- +define_validator("matches", function (pattern) + ensure_is_type(pattern, "string", messages.wrong_type_validator, "string", "pattern") + return _M.opt_check(pattern, string_match_function, "pattern", "string") +end) + + +--[[ + Returns a validator which calls the given function for each of the given values + and the tested value. If any of these calls return true, then this function + returns true. The value of check_values must be a non-empty table with all the + same types, and the value of check_function must not be nil. The optional name + is used for error messages and defaults to "check_values". The optional + check_type is used to make sure that the check type matches and defaults to + type(check_values[1]) - the table type. +]]-- +define_validator("any_of", function(check_values, check_function, name, check_type, table_type) + name = name or "check_values" + ensure_not_nil(check_values, messages.nil_validator, name) + ensure_is_type(check_values, "table", messages.wrong_type_validator, "table", name) + ensure_is_table(check_values, messages.empty_table_validator, name) + + table_type = table_type or type(check_values[1]) + ensure_is_table_type(check_values, table_type, messages.wrong_table_type_validator, table_type, name) + + ensure_not_nil(check_function, messages.nil_validator, "check_function") + ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function") + + check_type = check_type or table_type + return _M.opt_check(check_values, function(v1, v2) + for i, v in ipairs(v2) do + if check_function(v1, v) then return true end + end + return false + end, name, check_type) +end) + + +--[[ + Returns a validator that checks if a value exactly equals any of the given values. +]]-- +define_validator("equals_any_of", function(check_values) + return _M.opt_any_of(check_values, equality_function, "check_values") +end) + + +--[[ + Returns a validator that checks if a value matches any of the given patterns. +]]-- +define_validator("matches_any_of", function(patterns) + return _M.opt_any_of(patterns, string_match_function, "patterns", "string", "string") +end) + +--[[ + Returns a validator that checks if a value of expected type string exists in any of the given values. + The value of check_values must be a non-empty table with all the same types. + The optional name is used for error messages and defaults to "check_values". +]]-- +define_validator("contains_any_of", function(check_values, name) + return _M.opt_any_of(check_values, table_contains_function, name, "table", "string") +end) + +--[[ + Returns a validator that checks how a value compares (numerically) to a given + check_value. The value of check_val cannot be nil and must be a number. +]]-- +define_validator("greater_than", function(check_val) + ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val") + return _M.opt_check(check_val, greater_than_function, "check_val", "number") +end) +define_validator("greater_than_or_equal", function(check_val) + ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val") + return _M.opt_check(check_val, greater_than_or_equal_function, "check_val", "number") +end) +define_validator("less_than", function(check_val) + ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val") + return _M.opt_check(check_val, less_than_function, "check_val", "number") +end) +define_validator("less_than_or_equal", function(check_val) + ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val") + return _M.opt_check(check_val, less_than_or_equal_function, "check_val", "number") +end) + + +--[[ + A function to set the leeway (in seconds) used for is_not_before and is_not_expired. The + default is to use 0 seconds +]]-- +local system_leeway = 0 +function _M.set_system_leeway(leeway) + ensure_is_type(leeway, "number", "leeway must be a non-negative number") + ensure_is_non_negative(leeway, "leeway must be a non-negative number") + system_leeway = leeway +end + + +--[[ + A function to set the system clock used for is_not_before and is_not_expired. The + default is to use ngx.now +]]-- +local system_clock = ngx.now +function _M.set_system_clock(clock) + ensure_is_type(clock, "function", "clock must be a function") + -- Check that clock returns the correct value + local t = clock() + ensure_is_type(t, "number", "clock function must return a non-negative number") + ensure_is_non_negative(t, "clock function must return a non-negative number") + system_clock = clock +end + +-- Local helper function for date validation +local function validate_is_date(val, claim, jwt_json) + ensure_is_non_negative(val, messages.wrong_type_claim, claim, "positive numeric value") + return true +end + +-- Local helper for date formatting +local function format_date_on_error(date_check_function, error_msg) + ensure_is_type(date_check_function, "function", messages.wrong_type_validator, "function", "date_check_function") + ensure_is_type(error_msg, "string", messages.wrong_type_validator, "string", error_msg) + return function(val, claim, jwt_json) + local ret = date_check_function(val, claim, jwt_json) + if ret == false then + error(string.format("'%s' claim %s %s", claim, error_msg, ngx.http_time(val)), 0) + end + return true + end +end + +--[[ + Returns a validator that checks if the current time is not before the tested value + within the system's leeway. This means that: + val <= (system_clock() + system_leeway). +]]-- +define_validator("is_not_before", function() + return format_date_on_error( + _M.chain(validate_is_date, + function(val) + return val and less_than_or_equal_function(val, (system_clock() + system_leeway)) + end), + "not valid until" + ) +end) + + +--[[ + Returns a validator that checks if the current time is not equal to or after the + tested value within the system's leeway. This means that: + val > (system_clock() - system_leeway). +]]-- +define_validator("is_not_expired", function() + return format_date_on_error( + _M.chain(validate_is_date, + function(val) + return val and greater_than_function(val, (system_clock() - system_leeway)) + end), + "expired at" + ) +end) + +--[[ + Returns a validator that checks if the current time is the same as the tested value + within the system's leeway. This means that: + val >= (system_clock() - system_leeway) and val <= (system_clock() + system_leeway). +]]-- +define_validator("is_at", function() + local now = system_clock() + return format_date_on_error( + _M.chain(validate_is_date, + function(val) + local now = system_clock() + return val and + greater_than_or_equal_function(val, now - system_leeway) and + less_than_or_equal_function(val, now + system_leeway) + end), + "is only valid at" + ) +end) + + +return _M diff --git a/server/resty/jwt.lua b/server/resty/jwt.lua new file mode 100644 index 0000000..accba11 --- /dev/null +++ b/server/resty/jwt.lua @@ -0,0 +1,959 @@ +local cjson = require "cjson.safe" + +local evp = require "resty.evp" +local hmac = require "resty.hmac" +local resty_random = require "resty.random" +local cipher = require "resty.openssl.cipher" + +local _M = { _VERSION = "0.2.3" } + +local mt = { + __index = _M +} + +local string_rep = string.rep +local string_format = string.format +local string_sub = string.sub +local string_char = string.char +local table_concat = table.concat +local ngx_encode_base64 = ngx.encode_base64 +local ngx_decode_base64 = ngx.decode_base64 +local cjson_encode = cjson.encode +local cjson_decode = cjson.decode +local tostring = tostring +local error = error +local ipairs = ipairs +local type = type +local pcall = pcall +local assert = assert +local setmetatable = setmetatable +local pairs = pairs + +-- define string constants to avoid string garbage collection +local str_const = { + invalid_jwt= "invalid jwt string", + regex_join_msg = "%s.%s", + regex_join_delim = "([^%s]+)", + regex_split_dot = "%.", + regex_jwt_join_str = "%s.%s.%s", + raw_underscore = "raw_", + dash = "-", + empty = "", + dotdot = "..", + table = "table", + plus = "+", + equal = "=", + underscore = "_", + slash = "/", + header = "header", + typ = "typ", + JWT = "JWT", + JWE = "JWE", + payload = "payload", + signature = "signature", + encrypted_key = "encrypted_key", + alg = "alg", + enc = "enc", + kid = "kid", + exp = "exp", + nbf = "nbf", + iss = "iss", + full_obj = "__jwt", + x5c = "x5c", + x5u = 'x5u', + HS256 = "HS256", + HS512 = "HS512", + RS256 = "RS256", + ES256 = "ES256", + ES512 = "ES512", + RS512 = "RS512", + A128CBC_HS256 = "A128CBC-HS256", + A128CBC_HS256_CIPHER_MODE = "aes-128-cbc", + A256CBC_HS512 = "A256CBC-HS512", + A256CBC_HS512_CIPHER_MODE = "aes-256-cbc", + A256GCM = "A256GCM", + A256GCM_CIPHER_MODE = "aes-256-gcm", + RSA_OAEP_256 = "RSA-OAEP-256", + DIR = "dir", + reason = "reason", + verified = "verified", + number = "number", + string = "string", + funct = "function", + boolean = "boolean", + valid = "valid", + valid_issuers = "valid_issuers", + lifetime_grace_period = "lifetime_grace_period", + require_nbf_claim = "require_nbf_claim", + require_exp_claim = "require_exp_claim", + internal_error = "internal error", + everything_awesome = "everything is awesome~ :p" +} + +-- @function split string +local function split_string(str, delim) + local result = {} + local sep = string_format(str_const.regex_join_delim, delim) + for m in str:gmatch(sep) do + result[#result+1]=m + end + return result +end + +-- @function is nil or boolean +-- @return true if param is nil or true or false; false otherwise +local function is_nil_or_boolean(arg_value) + if arg_value == nil then + return true + end + + if type(arg_value) ~= str_const.boolean then + return false + end + + return true +end + +--@function get the raw part +--@param part_name +--@param jwt_obj +local function get_raw_part(part_name, jwt_obj) + local raw_part = jwt_obj[str_const.raw_underscore .. part_name] + if raw_part == nil then + local part = jwt_obj[part_name] + if part == nil then + error({reason="missing part " .. part_name}) + end + raw_part = _M:jwt_encode(part) + end + return raw_part +end + + +--@function decrypt payload +--@param secret_key to decrypt the payload +--@param encrypted payload +--@param encryption algorithm +--@param iv which was generated while encrypting the payload +--@param aad additional authenticated data (used when gcm mode is used) +--@param auth_tag authenticated tag (used when gcm mode is used) +--@return decrypted payloaf +local function decrypt_payload(secret_key, encrypted_payload, enc, iv_in, aad, auth_tag ) + local decrypted_payload, err + if enc == str_const.A128CBC_HS256 then + local aes_128_cbs_cipher = assert(cipher.new(str_const.A128CBC_HS256_CIPHER_MODE)) + decrypted_payload, err= aes_128_cbs_cipher:decrypt(secret_key, iv_in, encrypted_payload) + elseif enc == str_const.A256CBC_HS512 then + local aes_256_cbs_cipher = assert(cipher.new(str_const.A256CBC_HS512_CIPHER_MODE)) + decrypted_payload, err = aes_256_cbs_cipher:decrypt(secret_key, iv_in, encrypted_payload) + elseif enc == str_const.A256GCM then + local aes_256_gcm_cipher = assert(cipher.new(str_const.A256GCM_CIPHER_MODE)) + decrypted_payload, err = aes_256_gcm_cipher:decrypt(secret_key, iv_in, encrypted_payload, false, aad, auth_tag) + else + return nil, "unsupported enc: " .. enc + end + if not decrypted_payload or err then + return nil, err + end + return decrypted_payload +end + +-- @function encrypt payload using given secret +-- @param secret_key secret key to encrypt +-- @param message data to be encrypted. It could be lua table or string +-- @param enc algorithm to use for encryption +-- @param aad additional authenticated data (used when gcm mode is used) +local function encrypt_payload(secret_key, message, enc, aad ) + + if enc == str_const.A128CBC_HS256 then + local iv_rand = resty_random.bytes(16,true) + local aes_128_cbs_cipher = assert(cipher.new(str_const.A128CBC_HS256_CIPHER_MODE)) + local encrypted = aes_128_cbs_cipher:encrypt(secret_key, iv_rand, message) + return encrypted, iv_rand + + elseif enc == str_const.A256CBC_HS512 then + local iv_rand = resty_random.bytes(16,true) + local aes_256_cbs_cipher = assert(cipher.new(str_const.A256CBC_HS512_CIPHER_MODE)) + local encrypted = aes_256_cbs_cipher:encrypt(secret_key, iv_rand, message) + return encrypted, iv_rand + + elseif enc == str_const.A256GCM then + local iv_rand = resty_random.bytes(12,true) -- 96 bit IV is recommended for efficiency + local aes_256_gcm_cipher = assert(cipher.new(str_const.A256GCM_CIPHER_MODE)) + local encrypted = aes_256_gcm_cipher:encrypt(secret_key, iv_rand, message, false, aad) + local auth_tag = assert(aes_256_gcm_cipher:get_aead_tag()) + return encrypted, iv_rand, auth_tag + + else + return nil, nil , nil, "unsupported enc: " .. enc + end +end + +--@function hmac_digest : generate hmac digest based on key for input message +--@param mac_key +--@param input message +--@return hmac digest +local function hmac_digest(enc, mac_key, message) + if enc == str_const.A128CBC_HS256 then + return hmac:new(mac_key, hmac.ALGOS.SHA256):final(message) + elseif enc == str_const.A256CBC_HS512 then + return hmac:new(mac_key, hmac.ALGOS.SHA512):final(message) + else + error({reason="unsupported enc: " .. enc}) + end +end + +--@function dervice keys: it generates key if null based on encryption algorithm +--@param encryption type +--@param secret key +--@return secret key, mac key and encryption key +local function derive_keys(enc, secret_key) + local mac_key_len, enc_key_len = 16, 16 + + if enc == str_const.A256GCM then + mac_key_len, enc_key_len = 0, 32 -- we need 256 bit key + elseif enc == str_const.A128CBC_HS256 then + mac_key_len, enc_key_len = 16, 16 + elseif enc == str_const.A256CBC_HS512 then + mac_key_len, enc_key_len = 32, 32 + else + error({reason="unsupported payload encryption algorithm :" .. enc}) + end + + local secret_key_len = mac_key_len + enc_key_len + + if not secret_key then + secret_key = resty_random.bytes(secret_key_len, true) + end + + if #secret_key ~= secret_key_len then + error({reason="invalid pre-shared key"}) + end + + local mac_key = string_sub(secret_key, 1, mac_key_len) + local enc_key = string_sub(secret_key, mac_key_len + 1) + return secret_key, mac_key, enc_key +end + +local function get_payload_encoder(self) + return self.payload_encoder or cjson_encode +end + +local function get_payload_decoder(self) + return self.payload_decoder or cjson_decode +end + +--@function parse_jwe +--@param pre-shared key +--@encoded-header +local function parse_jwe(self, preshared_key, encoded_header, encoded_encrypted_key, encoded_iv, encoded_cipher_text, encoded_auth_tag) + + + local header = _M:jwt_decode(encoded_header, true) + if not header then + error({reason="invalid header: " .. encoded_header}) + end + + local alg = header.alg + if alg ~= str_const.DIR and alg ~= str_const.RSA_OAEP_256 then + error({reason="invalid algorithm: " .. alg}) + end + + local key, enc_key + if alg == str_const.DIR then + if not preshared_key then + error({reason="preshared key must not be null"}) + end + key, _, enc_key = derive_keys(header.enc, preshared_key) + elseif alg == str_const.RSA_OAEP_256 then + if not preshared_key then + error({reason="rsa private key must not be null"}) + end + local rsa_decryptor, err = evp.RSADecryptor:new(preshared_key, nil, evp.CONST.RSA_PKCS1_OAEP_PADDING, evp.CONST.SHA256_DIGEST) + if err then + error({reason="failed to create rsa object: ".. err}) + end + local secret_key, err = rsa_decryptor:decrypt(_M:jwt_decode(encoded_encrypted_key)) + if err or not secret_key then + error({reason="failed to decrypt key: " .. err}) + end + key, _, enc_key = derive_keys(header.enc, secret_key) + end + + local cipher_text = _M:jwt_decode(encoded_cipher_text) + local iv = _M:jwt_decode(encoded_iv) + local signature_or_tag = _M:jwt_decode(encoded_auth_tag) + local basic_jwe = { + internal = { + encoded_header = encoded_header, + cipher_text = cipher_text, + key = key, + iv = iv + }, + header = header, + signature = signature_or_tag + } + + local payload, err = decrypt_payload(enc_key, cipher_text, header.enc, iv, encoded_header, signature_or_tag) + if err then + error({reason="failed to decrypt payload: " .. err}) + + else + basic_jwe.payload = get_payload_decoder(self)(payload) + basic_jwe.internal.json_payload=payload + end + return basic_jwe +end + +-- @function parse_jwt +-- @param encoded header +-- @param encoded +-- @param signature +-- @return jwt table +local function parse_jwt(encoded_header, encoded_payload, signature) + local header = _M:jwt_decode(encoded_header, true) + if not header then + error({reason="invalid header: " .. encoded_header}) + end + + local payload = _M:jwt_decode(encoded_payload, true) + if not payload then + error({reason="invalid payload: " .. encoded_payload}) + end + + local basic_jwt = { + raw_header=encoded_header, + raw_payload=encoded_payload, + header=header, + payload=payload, + signature=signature + } + return basic_jwt + +end + +-- @function parse token - this can be JWE or JWT token +-- @param token string +-- @return jwt/jwe tables +local function parse(self, secret, token_str) + local tokens = split_string(token_str, str_const.regex_split_dot) + local num_tokens = #tokens + if num_tokens == 3 then + return parse_jwt(tokens[1], tokens[2], tokens[3]) + elseif num_tokens == 4 then + return parse_jwe(self, secret, tokens[1], nil, tokens[2], tokens[3], tokens[4]) + elseif num_tokens == 5 then + return parse_jwe(self, secret, tokens[1], tokens[2], tokens[3], tokens[4], tokens[5]) + else + error({reason=str_const.invalid_jwt}) + end +end + +--@function jwt encode : it converts into base64 encoded string. if input is a table, it convets into +-- json before converting to base64 string +--@param payloaf +--@return base64 encoded payloaf +function _M.jwt_encode(self, ori, is_payload) + if type(ori) == str_const.table then + ori = is_payload and get_payload_encoder(self)(ori) or cjson_encode(ori) + end + local res = ngx_encode_base64(ori):gsub(str_const.plus, str_const.dash):gsub(str_const.slash, str_const.underscore):gsub(str_const.equal, str_const.empty) + return res +end + + + +--@function jwt decode : decode bas64 encoded string +function _M.jwt_decode(self, b64_str, json_decode, is_payload) + b64_str = b64_str:gsub(str_const.dash, str_const.plus):gsub(str_const.underscore, str_const.slash) + + local reminder = #b64_str % 4 + if reminder > 0 then + b64_str = b64_str .. string_rep(str_const.equal, 4 - reminder) + end + local data = ngx_decode_base64(b64_str) + if not data then + return nil + end + if json_decode then + data = is_payload and get_payload_decoder(self)(data) or cjson_decode(data) + end + return data +end + +--- Initialize the trusted certs +-- During RS256 verify, we'll make sure the +-- cert was signed by one of these +function _M.set_trusted_certs_file(self, filename) + self.trusted_certs_file = filename +end +_M.trusted_certs_file = nil + +--- Set a whitelist of allowed algorithms +-- E.g., jwt:set_alg_whitelist({RS256=1,HS256=1}) +-- +-- @param algorithms - A table with keys for the supported algorithms +-- If the table is non-nil, during +-- verify, the alg must be in the table +function _M.set_alg_whitelist(self, algorithms) + self.alg_whitelist = algorithms +end + +_M.alg_whitelist = nil + + +--- Returns the list of default validations that will be +--- applied upon the verification of a jwt. +function _M.get_default_validation_options(self, jwt_obj) + return { + [str_const.require_exp_claim]=jwt_obj[str_const.payload].exp ~= nil, + [str_const.require_nbf_claim]=jwt_obj[str_const.payload].nbf ~= nil + } +end + +--- Set a function used to retrieve the content of x5u urls +-- +-- @param retriever_function - A pointer to a function. This function should be +-- defined to accept three string parameters. First one +-- will be the value of the 'x5u' attribute. Second +-- one will be the value of the 'iss' attribute, would +-- it be defined in the jwt. Third one will be the value +-- of the 'iss' attribute, would it be defined in the jwt. +-- This function should return the matching certificate. +function _M.set_x5u_content_retriever(self, retriever_function) + if type(retriever_function) ~= str_const.funct then + error("'retriever_function' is expected to be a function", 0) + end + self.x5u_content_retriever = retriever_function +end + +_M.x5u_content_retriever = nil + +-- https://tools.ietf.org/html/rfc7516#appendix-B.3 +-- TODO: do it in lua way +local function binlen(s) + if type(s) ~= 'string' then return end + + local len = 8 * #s + + return string_char(len / 0x0100000000000000 % 0x100) + .. string_char(len / 0x0001000000000000 % 0x100) + .. string_char(len / 0x0000010000000000 % 0x100) + .. string_char(len / 0x0000000100000000 % 0x100) + .. string_char(len / 0x0000000001000000 % 0x100) + .. string_char(len / 0x0000000000010000 % 0x100) + .. string_char(len / 0x0000000000000100 % 0x100) + .. string_char(len / 0x0000000000000001 % 0x100) +end + +--@function sign jwe payload +--@param secret key : if used pre-shared or RSA key +--@param jwe payload +--@return jwe token +local function sign_jwe(self, secret_key, jwt_obj) + local header = jwt_obj.header + local enc = header.enc + local alg = header.alg + + -- remove type + if header.typ then + header.typ = nil + end + + -- TODO: implement logic for creating enc key and mac key and then encrypt key + local key, encrypted_key, mac_key, enc_key + local encoded_header = _M:jwt_encode(header) + local payload_to_encrypt = get_payload_encoder(self)(jwt_obj.payload) + if alg == str_const.DIR then + _, mac_key, enc_key = derive_keys(enc, secret_key) + encrypted_key = "" + elseif alg == str_const.RSA_OAEP_256 then + local cert, err + if secret_key:find("CERTIFICATE") then + cert, err = evp.Cert:new(secret_key) + elseif secret_key:find("PUBLIC KEY") then + cert, err = evp.PublicKey:new(secret_key) + end + if not cert then + error({reason="Decode secret is not a valid cert/public key: " .. (err and err or secret_key)}) + end + local rsa_encryptor = evp.RSAEncryptor:new(cert, evp.CONST.RSA_PKCS1_OAEP_PADDING, evp.CONST.SHA256_DIGEST) + if err then + error("failed to create rsa object for encryption ".. err) + end + key, mac_key, enc_key = derive_keys(enc) + encrypted_key, err = rsa_encryptor:encrypt(key) + if err or not encrypted_key then + error({reason="failed to encrypt key " .. (err or "")}) + end + else + error({reason="unsupported alg: " .. alg}) + end + + local cipher_text, iv, auth_tag, err = encrypt_payload(enc_key, payload_to_encrypt, enc, encoded_header) + if err then + error({reason="error while encrypting payload. Error: " .. err}) + end + + if not auth_tag then + local encoded_header_length = binlen(encoded_header) + local mac_input = table_concat({encoded_header , iv, cipher_text , encoded_header_length}) + local mac = hmac_digest(enc, mac_key, mac_input) + auth_tag = string_sub(mac, 1, #mac/2) + end + + local jwe_table = {encoded_header, _M:jwt_encode(encrypted_key), _M:jwt_encode(iv), + _M:jwt_encode(cipher_text), _M:jwt_encode(auth_tag)} + return table_concat(jwe_table, ".", 1, 5) +end + +--@function get_secret_str : returns the secret if it is a string, or the result of a function +--@param either the string secret or a function that takes a string parameter and returns a string or nil +--@param jwt payload +--@return the secret as a string or as a function +local function get_secret_str(secret_or_function, jwt_obj) + if type(secret_or_function) == str_const.funct then + -- Only use with hmac algorithms + local alg = jwt_obj[str_const.header][str_const.alg] + if alg ~= str_const.HS256 and alg ~= str_const.HS512 then + error({reason="secret function can only be used with hmac alg: " .. alg}) + end + + -- Pull out the kid value from the header + local kid_val = jwt_obj[str_const.header][str_const.kid] + if kid_val == nil then + error({reason="secret function specified without kid in header"}) + end + + -- Call the function + return secret_or_function(kid_val) or error({reason="function returned nil for kid: " .. kid_val}) + elseif type(secret_or_function) == str_const.string then + -- Just return the string + return secret_or_function + else + -- Throw an error + error({reason="invalid secret type (must be string or function)"}) + end +end + +--@function sign : create a jwt/jwe signature from jwt_object +--@param secret key +--@param jwt/jwe payload +function _M.sign(self, secret_key, jwt_obj) + -- header typ check + local typ = jwt_obj[str_const.header][str_const.typ] + -- Optional header typ check [See http://tools.ietf.org/html/draft-ietf-oauth-json-web-token-25#section-5.1] + if typ ~= nil then + if typ ~= str_const.JWT and typ ~= str_const.JWE then + error({reason="invalid typ: " .. typ}) + end + end + + if typ == str_const.JWE or jwt_obj.header.enc then + return sign_jwe(self, secret_key, jwt_obj) + end + -- header alg check + local raw_header = get_raw_part(str_const.header, jwt_obj) + local raw_payload = get_raw_part(str_const.payload, jwt_obj) + local message = string_format(str_const.regex_join_msg, raw_header, raw_payload) + local alg = jwt_obj[str_const.header][str_const.alg] + local signature = "" + if alg == str_const.HS256 then + local secret_str = get_secret_str(secret_key, jwt_obj) + signature = hmac:new(secret_str, hmac.ALGOS.SHA256):final(message) + elseif alg == str_const.HS512 then + local secret_str = get_secret_str(secret_key, jwt_obj) + signature = hmac:new(secret_str, hmac.ALGOS.SHA512):final(message) + elseif alg == str_const.RS256 or alg == str_const.RS512 then + local signer, err = evp.RSASigner:new(secret_key) + if not signer then + error({reason="signer error: " .. err}) + end + if alg == str_const.RS256 then + signature = signer:sign(message, evp.CONST.SHA256_DIGEST) + elseif alg == str_const.RS512 then + signature = signer:sign(message, evp.CONST.SHA512_DIGEST) + end + elseif alg == str_const.ES256 or alg == str_const.ES512 then + local signer, err = evp.ECSigner:new(secret_key) + if not signer then + error({reason="signer error: " .. err}) + end + -- OpenSSL will generate a DER encoded signature that needs to be converted + local der_signature = "" + if alg == str_const.ES256 then + der_signature = signer:sign(message, evp.CONST.SHA256_DIGEST) + elseif alg == str_const.ES512 then + der_signature = signer:sign(message, evp.CONST.SHA512_DIGEST) + end + -- Perform DER to RAW signature conversion + signature, err = signer:get_raw_sig(der_signature) + if not signature then + error({reason="signature error: " .. err}) + end + else + error({reason="unsupported alg: " .. alg}) + end + -- return full jwt string + return string_format(str_const.regex_join_msg, message , _M:jwt_encode(signature)) + +end + +--@function load jwt +--@param jwt string token +--@param secret +function _M.load_jwt(self, jwt_str, secret) + local success, ret = pcall(parse, self, secret, jwt_str) + if not success then + return { + valid=false, + verified=false, + reason=ret[str_const.reason] or str_const.invalid_jwt + } + end + + local jwt_obj = ret + jwt_obj[str_const.verified] = false + jwt_obj[str_const.valid] = true + return jwt_obj +end + +--@function verify jwe object +--@param jwt object +--@return jwt object with reason whether verified or not +local function verify_jwe_obj(jwt_obj) + + if jwt_obj[str_const.header][str_const.enc] ~= str_const.A256GCM then -- tag gets authenticated during decryption + local _, mac_key, _ = derive_keys(jwt_obj.header.enc, jwt_obj.internal.key) + local encoded_header = jwt_obj.internal.encoded_header + + local encoded_header_length = binlen(encoded_header) + local mac_input = table_concat({encoded_header , jwt_obj.internal.iv, jwt_obj.internal.cipher_text, + encoded_header_length}) + local mac = hmac_digest(jwt_obj.header.enc, mac_key, mac_input) + local auth_tag = string_sub(mac, 1, #mac/2) + + if auth_tag ~= jwt_obj.signature then + jwt_obj[str_const.reason] = "signature mismatch: " .. + tostring(jwt_obj[str_const.signature]) + end + end + + jwt_obj.internal = nil + jwt_obj.signature = nil + + if not jwt_obj[str_const.reason] then + jwt_obj[str_const.verified] = true + jwt_obj[str_const.reason] = str_const.everything_awesome + end + + return jwt_obj +end + +--@function extract certificate +--@param jwt object +--@return decoded certificate +local function extract_certificate(jwt_obj, x5u_content_retriever) + local x5c = jwt_obj[str_const.header][str_const.x5c] + if x5c ~= nil and x5c[1] ~= nil then + -- TODO Might want to add support for intermediaries that we + -- don't have in our trusted chain (items 2... if present) + + local cert_str = ngx_decode_base64(x5c[1]) + if not cert_str then + jwt_obj[str_const.reason] = "Malformed x5c header" + end + + return cert_str + end + + local x5u = jwt_obj[str_const.header][str_const.x5u] + if x5u ~= nil then + -- TODO Ensure the url starts with https:// + -- cf. https://tools.ietf.org/html/rfc7517#section-4.6 + + if x5u_content_retriever == nil then + jwt_obj[str_const.reason] = "No function has been provided to retrieve the content pointed at by the 'x5u'." + return nil + end + + -- TODO Maybe validate the url against an optional list whitelisted url prefixes? + -- cf. https://news.ycombinator.com/item?id=9302394 + + local iss = jwt_obj[str_const.payload][str_const.iss] + local kid = jwt_obj[str_const.header][str_const.kid] + local success, ret = pcall(x5u_content_retriever, x5u, iss, kid) + + if not success then + jwt_obj[str_const.reason] = "An error occured while invoking the x5u_content_retriever function." + return nil + end + + return ret + end + + -- TODO When both x5c and x5u are defined, the implementation should + -- ensure their content match + -- cf. https://tools.ietf.org/html/rfc7517#section-4.6 + + jwt_obj[str_const.reason] = "Unsupported RS256 key model" + return nil + -- TODO - Implement jwk and kid based models... +end + +local function get_claim_spec_from_legacy_options(self, options) + local claim_spec = { } + local jwt_validators = require "resty.jwt-validators" + + if options[str_const.valid_issuers] ~= nil then + claim_spec[str_const.iss] = jwt_validators.equals_any_of(options[str_const.valid_issuers]) + end + + if options[str_const.lifetime_grace_period] ~= nil then + jwt_validators.set_system_leeway(options[str_const.lifetime_grace_period] or 0) + + -- If we have a leeway set, then either an NBF or an EXP should also exist requireds are added below + if options[str_const.require_nbf_claim] ~= true and options[str_const.require_exp_claim] ~= true then + claim_spec[str_const.full_obj] = jwt_validators.require_one_of({ str_const.nbf, str_const.exp }) + end + end + + if not is_nil_or_boolean(options[str_const.require_nbf_claim]) then + error(string.format("'%s' validation option is expected to be a boolean.", str_const.require_nbf_claim), 0) + end + + if not is_nil_or_boolean(options[str_const.require_exp_claim]) then + error(string.format("'%s' validation option is expected to be a boolean.", str_const.require_exp_claim), 0) + end + + if options[str_const.lifetime_grace_period] ~= nil or options[str_const.require_nbf_claim] ~= nil or options[str_const.require_exp_claim] ~= nil then + if options[str_const.require_nbf_claim] == true then + claim_spec[str_const.nbf] = jwt_validators.is_not_before() + else + claim_spec[str_const.nbf] = jwt_validators.opt_is_not_before() + end + + if options[str_const.require_exp_claim] == true then + claim_spec[str_const.exp] = jwt_validators.is_not_expired() + else + claim_spec[str_const.exp] = jwt_validators.opt_is_not_expired() + end + end + + return claim_spec +end + +local function is_legacy_validation_options(options) + + -- Validation options MUST be a table + if type(options) ~= str_const.table then + return false + end + + -- Validation options MUST have at least one of these, and must ONLY have these + local legacy_options = { } + legacy_options[str_const.valid_issuers]=1 + legacy_options[str_const.lifetime_grace_period]=1 + legacy_options[str_const.require_nbf_claim]=1 + legacy_options[str_const.require_exp_claim]=1 + + local is_legacy = false + for k in pairs(options) do + if legacy_options[k] ~= nil then + is_legacy = true + else + return false + end + end + return is_legacy +end + +-- Validates the claims for the given (parsed) object +local function validate_claims(self, jwt_obj, ...) + local claim_specs = {...} + if #claim_specs == 0 then + table.insert(claim_specs, _M:get_default_validation_options(jwt_obj)) + end + + if jwt_obj[str_const.reason] ~= nil then + return false + end + + -- Encode the current jwt_obj and use it when calling the individual validation functions + local jwt_json = cjson_encode(jwt_obj) + + -- Validate all our specs + for _, claim_spec in ipairs(claim_specs) do + if is_legacy_validation_options(claim_spec) then + claim_spec = get_claim_spec_from_legacy_options(self, claim_spec) + end + for claim, fx in pairs(claim_spec) do + if type(fx) ~= str_const.funct then + error("Claim spec value must be a function - see jwt-validators.lua for helper functions", 0) + end + + local val = claim == str_const.full_obj and cjson_decode(jwt_json) or jwt_obj.payload[claim] + local success, ret = pcall(fx, val, claim, jwt_json) + if not success then + jwt_obj[str_const.reason] = ret.reason or string.gsub(ret, "^.-:%d-: ", "") + return false + elseif ret == false then + jwt_obj[str_const.reason] = string.format("Claim '%s' ('%s') returned failure", claim, val) + return false + end + end + end + + -- Everything was good + return true +end + +--@function verify jwt object +--@param secret +--@param jwt_object +--@leeway +--@return verified jwt payload or jwt object with error code +function _M.verify_jwt_obj(self, secret, jwt_obj, ...) + if not jwt_obj.valid then + return jwt_obj + end + + -- validate any claims that have been passed in + if not validate_claims(self, jwt_obj, ...) then + return jwt_obj + end + + -- if jwe, invoked verify jwe + if jwt_obj[str_const.header][str_const.enc] then + return verify_jwe_obj(jwt_obj) + end + + local alg = jwt_obj[str_const.header][str_const.alg] + + local jwt_str = string_format(str_const.regex_jwt_join_str, jwt_obj.raw_header , jwt_obj.raw_payload , jwt_obj.signature) + + if self.alg_whitelist ~= nil then + if self.alg_whitelist[alg] == nil then + return {verified=false, reason="whitelist unsupported alg: " .. alg} + end + end + + if alg == str_const.HS256 or alg == str_const.HS512 then + local success, ret = pcall(_M.sign, self, secret, jwt_obj) + if not success then + -- syntax check + jwt_obj[str_const.reason] = ret[str_const.reason] or str_const.internal_error + elseif jwt_str ~= ret then + -- signature check + jwt_obj[str_const.reason] = "signature mismatch: " .. jwt_obj[str_const.signature] + end + elseif alg == str_const.RS256 or alg == str_const.RS512 or alg == str_const.ES256 or alg == str_const.ES512 then + local cert, err + if self.trusted_certs_file ~= nil then + local cert_str = extract_certificate(jwt_obj, self.x5u_content_retriever) + if not cert_str then + return jwt_obj + end + cert, err = evp.Cert:new(cert_str) + if not cert then + jwt_obj[str_const.reason] = "Unable to extract signing cert from JWT: " .. err + return jwt_obj + end + -- Try validating against trusted CA's, then a cert passed as secret + local trusted = cert:verify_trust(self.trusted_certs_file) + if not trusted then + jwt_obj[str_const.reason] = "Cert used to sign the JWT isn't trusted: " .. err + return jwt_obj + end + elseif secret ~= nil then + if secret:find("CERTIFICATE") then + cert, err = evp.Cert:new(secret) + elseif secret:find("PUBLIC KEY") then + cert, err = evp.PublicKey:new(secret) + end + if not cert then + jwt_obj[str_const.reason] = "Decode secret is not a valid cert/public key" + return jwt_obj + end + else + jwt_obj[str_const.reason] = "No trusted certs loaded" + return jwt_obj + end + local verifier = '' + if alg == str_const.RS256 or alg == str_const.RS512 then + verifier = evp.RSAVerifier:new(cert) + elseif alg == str_const.ES256 or alg == str_const.ES512 then + verifier = evp.ECVerifier:new(cert) + end + if not verifier then + -- Internal error case, should not happen... + jwt_obj[str_const.reason] = "Failed to build verifier " .. err + return jwt_obj + end + + -- assemble jwt parts + local raw_header = get_raw_part(str_const.header, jwt_obj) + local raw_payload = get_raw_part(str_const.payload, jwt_obj) + + local message =string_format(str_const.regex_join_msg, raw_header , raw_payload) + local sig = _M:jwt_decode(jwt_obj[str_const.signature], false) + + if not sig then + jwt_obj[str_const.reason] = "Wrongly encoded signature" + return jwt_obj + end + + local verified = false + err = "verify error: reason unknown" + + if alg == str_const.RS256 or alg == str_const.ES256 then + verified, err = verifier:verify(message, sig, evp.CONST.SHA256_DIGEST) + elseif alg == str_const.RS512 or alg == str_const.ES512 then + verified, err = verifier:verify(message, sig, evp.CONST.SHA512_DIGEST) + end + if not verified then + jwt_obj[str_const.reason] = err + end + else + jwt_obj[str_const.reason] = "Unsupported algorithm " .. alg + end + + if not jwt_obj[str_const.reason] then + jwt_obj[str_const.verified] = true + jwt_obj[str_const.reason] = str_const.everything_awesome + end + return jwt_obj + +end + + +function _M.verify(self, secret, jwt_str, ...) + local jwt_obj = _M.load_jwt(self, jwt_str, secret) + if not jwt_obj.valid then + return {verified=false, reason=jwt_obj[str_const.reason]} + end + return _M.verify_jwt_obj(self, secret, jwt_obj, ...) + +end + +function _M.set_payload_encoder(self, encoder) + if type(encoder) ~= "function" then + error({reason="payload encoder must be function"}) + end + self.payload_encoder = encoder +end + + +function _M.set_payload_decoder(self, decoder) + if type(decoder) ~= "function" then + error({reason="payload decoder must be function"}) + end + self.payload_decoder= decoder +end + + +function _M.new() + return setmetatable({}, mt) +end + +return _M diff --git a/server/resty/openidc.lua b/server/resty/openidc.lua new file mode 100644 index 0000000..246414e --- /dev/null +++ b/server/resty/openidc.lua @@ -0,0 +1,1870 @@ +--[[ +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. + +*************************************************************************** +Copyright (C) 2017-2019 ZmartZone IAM +Copyright (C) 2015-2017 Ping Identity Corporation +All rights reserved. + +For further information please contact: + + Ping Identity Corporation + 1099 18th St Suite 2950 + Denver, CO 80202 + 303.468.2900 + http://www.pingidentity.com + +DISCLAIMER OF WARRANTIES: + +THE SOFTWARE PROVIDED HEREUNDER IS PROVIDED ON AN "AS IS" BASIS, WITHOUT +ANY WARRANTIES OR REPRESENTATIONS EXPRESS, IMPLIED OR STATUTORY; INCLUDING, +WITHOUT LIMITATION, WARRANTIES OF QUALITY, PERFORMANCE, NONINFRINGEMENT, +MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. NOR ARE THERE ANY +WARRANTIES CREATED BY A COURSE OR DEALING, COURSE OF PERFORMANCE OR TRADE +USAGE. FURTHERMORE, THERE ARE NO WARRANTIES THAT THE SOFTWARE WILL MEET +YOUR NEEDS OR BE FREE FROM ERRORS, OR THAT THE OPERATION OF THE SOFTWARE +WILL BE UNINTERRUPTED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +@Author: Hans Zandbelt - hans.zandbelt@zmartzone.eu +--]] + +local require = require +local cjson = require("cjson") +local cjson_s = require("cjson.safe") +local http = require("resty.http") +local r_session = require("resty.session") +local string = string +local ipairs = ipairs +local pairs = pairs +local type = type +local ngx = ngx +local b64 = ngx.encode_base64 +local unb64 = ngx.decode_base64 + +local log = ngx.log +local DEBUG = ngx.DEBUG +local ERROR = ngx.ERR +local WARN = ngx.WARN + +local function token_auth_method_precondition(method, required_field) + return function(opts) + if not opts[required_field] then + log(DEBUG, "Can't use " .. method .. " without opts." .. required_field) + return false + end + return true + end +end + +local supported_token_auth_methods = { + client_secret_basic = true, + client_secret_post = true, + private_key_jwt = token_auth_method_precondition('private_key_jwt', 'client_rsa_private_key'), + client_secret_jwt = token_auth_method_precondition('client_secret_jwt', 'client_secret') +} + +local openidc = { + _VERSION = "1.7.5" +} +openidc.__index = openidc + +local function store_in_session(opts, feature) + -- We don't have a whitelist of features to enable + if not opts.session_contents then + return true + end + + return opts.session_contents[feature] +end + +-- set value in server-wide cache if available +local function openidc_cache_set(type, key, value, exp) + local dict = ngx.shared[type] + if dict and (exp > 0) then + local success, err, forcible = dict:set(key, value, exp) + log(DEBUG, "cache set: success=", success, " err=", err, " forcible=", forcible) + end +end + +-- retrieve value from server-wide cache if available +local function openidc_cache_get(type, key) + local dict = ngx.shared[type] + local value + if dict then + value = dict:get(key) + if value then log(DEBUG, "cache hit: type=", type, " key=", key) end + end + return value +end + +-- invalidate values of server-wide cache +local function openidc_cache_invalidate(type) + local dict = ngx.shared[type] + if dict then + log(DEBUG, "flushing cache for " .. type) + dict.flush_all(dict) + local nbr = dict.flush_expired(dict) + end +end + +-- invalidate all server-wide caches +function openidc.invalidate_caches() + openidc_cache_invalidate("discovery") + openidc_cache_invalidate("jwks") + openidc_cache_invalidate("introspection") + openidc_cache_invalidate("jwt_verification") +end + +-- validate the contents of and id_token +local function openidc_validate_id_token(opts, id_token, nonce) + + -- check issuer + if opts.discovery.issuer ~= id_token.iss then + log(ERROR, "issuer \"", id_token.iss, "\" in id_token is not equal to the issuer from the discovery document \"", opts.discovery.issuer, "\"") + return false + end + + -- check sub + if not id_token.sub then + log(ERROR, "no \"sub\" claim found in id_token") + return false + end + + -- check nonce + if nonce and nonce ~= id_token.nonce then + log(ERROR, "nonce \"", id_token.nonce, "\" in id_token is not equal to the nonce that was sent in the request \"", nonce, "\"") + return false + end + + -- check issued-at timestamp + if not id_token.iat then + log(ERROR, "no \"iat\" claim found in id_token") + return false + end + + local slack = opts.iat_slack and opts.iat_slack or 120 + if id_token.iat > (ngx.time() + slack) then + log(ERROR, "id_token not yet valid: id_token.iat=", id_token.iat, ", ngx.time()=", ngx.time(), ", slack=", slack) + return false + end + + -- check expiry timestamp + if not id_token.exp then + log(ERROR, "no \"exp\" claim found in id_token") + return false + end + + if (id_token.exp + slack) < ngx.time() then + log(ERROR, "token expired: id_token.exp=", id_token.exp, ", ngx.time()=", ngx.time()) + return false + end + + -- check audience (array or string) + if not id_token.aud then + log(ERROR, "no \"aud\" claim found in id_token") + return false + end + + if (type(id_token.aud) == "table") then + for _, value in pairs(id_token.aud) do + if value == opts.client_id then + return true + end + end + log(ERROR, "no match found token audience array: client_id=", opts.client_id) + return false + elseif (type(id_token.aud) == "string") then + if id_token.aud ~= opts.client_id then + log(ERROR, "token audience does not match: id_token.aud=", id_token.aud, ", client_id=", opts.client_id) + return false + end + end + return true +end + +local function get_first(table_or_string) + local res = table_or_string + if table_or_string and type(table_or_string) == 'table' then + res = table_or_string[1] + end + return res +end + +local function get_first_header(headers, header_name) + local header = headers[header_name] + return get_first(header) +end + +local function get_first_header_and_strip_whitespace(headers, header_name) + local header = get_first_header(headers, header_name) + return header and header:gsub('%s', '') +end + +local function get_forwarded_parameter(headers, param_name) + local forwarded = get_first_header(headers, 'Forwarded') + local params = {} + if forwarded then + local function parse_parameter(pv) + local name, value = pv:match("^%s*([^=]+)%s*=%s*(.-)%s*$") + if name and value then + if value:sub(1, 1) == '"' then + value = value:sub(2, -2) + end + params[name:lower()] = value + end + end + + -- this assumes there is no quoted comma inside the header's value + -- which should be fine as comma is not legal inside a node name, + -- a URI scheme or a host name. The only thing that might bite us + -- are extensions. + local first_part = forwarded + local first_comma = forwarded:find("%s*,%s*") + if first_comma then + first_part = forwarded:sub(1, first_comma - 1) + end + first_part:gsub("[^;]+", parse_parameter) + end + return params[param_name:gsub("^%s*(.-)%s*$", "%1"):lower()] +end + +local function get_scheme(headers) + return get_forwarded_parameter(headers, 'proto') + or get_first_header_and_strip_whitespace(headers, 'X-Forwarded-Proto') + or ngx.var.scheme +end + +local function get_host_name_from_x_header(headers) + local header = get_first_header_and_strip_whitespace(headers, 'X-Forwarded-Host') + return header and header:gsub('^([^,]+),?.*$', '%1') +end + +local function get_host_name(headers) + return get_forwarded_parameter(headers, 'host') + or get_host_name_from_x_header(headers) + or ngx.var.http_host +end + +-- assemble the redirect_uri +local function openidc_get_redirect_uri(opts, session) + local path = opts.redirect_uri_path + if opts.redirect_uri then + if opts.redirect_uri:sub(1, 1) == '/' then + path = opts.redirect_uri + else + return opts.redirect_uri + end + end + local headers = ngx.req.get_headers() + local scheme = opts.redirect_uri_scheme or get_scheme(headers) + local host = get_host_name(headers) + if not host then + -- possibly HTTP 1.0 and no Host header + if session then session:close() end + ngx.exit(ngx.HTTP_BAD_REQUEST) + end + return scheme .. "://" .. host .. path +end + +-- perform base64url decoding +local function openidc_base64_url_decode(input) + local reminder = #input % 4 + if reminder > 0 then + local padlen = 4 - reminder + input = input .. string.rep('=', padlen) + end + input = input:gsub('%-', '+'):gsub('_', '/') + return unb64(input) +end + +-- perform base64url encoding +local function openidc_base64_url_encode(input) + local output = b64(input, true) + return output:gsub('%+', '-'):gsub('/', '_') +end + +local function openidc_combine_uri(uri, params) + if params == nil or next(params) == nil then + return uri + end + local sep = "?" + if string.find(uri, "?", 1, true) then + sep = "&" + end + return uri .. sep .. ngx.encode_args(params) +end + +local function decorate_request(http_request_decorator, req) + return http_request_decorator and http_request_decorator(req) or req +end + +local function openidc_s256(verifier) + local sha256 = (require 'resty.sha256'):new() + sha256:update(verifier) + return openidc_base64_url_encode(sha256:final()) +end + +-- send the browser of to the OP's authorization endpoint +local function openidc_authorize(opts, session, target_url, prompt) + local resty_random = require("resty.random") + local resty_string = require("resty.string") + local err + + -- generate state and nonce + local state = resty_string.to_hex(resty_random.bytes(16)) + local nonce = (opts.use_nonce == nil or opts.use_nonce) + and resty_string.to_hex(resty_random.bytes(16)) + local code_verifier = opts.use_pkce and openidc_base64_url_encode(resty_random.bytes(32)) + + -- assemble the parameters to the authentication request + local params = { + client_id = opts.client_id, + response_type = "code", + scope = opts.scope and opts.scope or "openid email profile", + redirect_uri = openidc_get_redirect_uri(opts, session), + state = state, + } + + if nonce then + params.nonce = nonce + end + + if prompt then + params.prompt = prompt + end + + if opts.display then + params.display = opts.display + end + + if code_verifier then + params.code_challenge_method = 'S256' + params.code_challenge = openidc_s256(code_verifier) + end + + -- merge any provided extra parameters + if opts.authorization_params then + for k, v in pairs(opts.authorization_params) do params[k] = v end + end + + -- store state in the session + session.data.original_url = target_url + session.data.state = state + session.data.nonce = nonce + session.data.code_verifier = code_verifier + session.data.last_authenticated = ngx.time() + + if opts.lifecycle and opts.lifecycle.on_created then + err = opts.lifecycle.on_created(session) + if err then + log(WARN, "failed in `on_created` handler: " .. err) + return err + end + end + + session:save() + + -- redirect to the /authorization endpoint + ngx.header["Cache-Control"] = "no-cache, no-store, max-age=0" + return ngx.redirect(openidc_combine_uri(opts.discovery.authorization_endpoint, params)) +end + +-- parse the JSON result from a call to the OP +local function openidc_parse_json_response(response, ignore_body_on_success) + local ignore_body_on_success = ignore_body_on_success or false + + local err + local res + + -- check the response from the OP + if response.status ~= 200 then + err = "response indicates failure, status=" .. response.status .. ", body=" .. response.body + else + if ignore_body_on_success then + return nil, nil + end + + -- decode the response and extract the JSON object + res = cjson_s.decode(response.body) + + if not res then + err = "JSON decoding failed" + end + end + + return res, err +end + +local function openidc_configure_timeouts(httpc, timeout) + if timeout then + if type(timeout) == "table" then + local r, e = httpc:set_timeouts(timeout.connect or 0, timeout.send or 0, timeout.read or 0) + else + local r, e = httpc:set_timeout(timeout) + end + end +end + +-- Set outgoing proxy options +local function openidc_configure_proxy(httpc, proxy_opts) + if httpc and proxy_opts and type(proxy_opts) == "table" then + log(DEBUG, "openidc_configure_proxy : use http proxy") + httpc:set_proxy_options(proxy_opts) + else + log(DEBUG, "openidc_configure_proxy : don't use http proxy") + end +end + +-- make a call to the token endpoint +function openidc.call_token_endpoint(opts, endpoint, body, auth, endpoint_name, ignore_body_on_success) + local ignore_body_on_success = ignore_body_on_success or false + + local ep_name = endpoint_name or 'token' + if not endpoint then + return nil, 'no endpoint URI for ' .. ep_name + end + + local headers = { + ["Content-Type"] = "application/x-www-form-urlencoded" + } + + if auth then + if auth == "client_secret_basic" then + if opts.client_secret then + headers.Authorization = "Basic " .. b64(ngx.escape_uri(opts.client_id) .. ":" .. ngx.escape_uri(opts.client_secret)) + else + -- client_secret must not be set if Windows Integrated Authentication (WIA) is used with + -- Active Directory Federation Services (AD FS) 4.0 (or newer) on Windows Server 2016 (or newer) + headers.Authorization = "Basic " .. b64(ngx.escape_uri(opts.client_id) .. ":") + end + log(DEBUG, "client_secret_basic: authorization header '" .. headers.Authorization .. "'") + + elseif auth == "client_secret_post" then + body.client_id = opts.client_id + if opts.client_secret then + body.client_secret = opts.client_secret + end + log(DEBUG, "client_secret_post: client_id and client_secret being sent in POST body") + + elseif auth == "private_key_jwt" or auth == "client_secret_jwt" then + local key = auth == "private_key_jwt" and opts.client_rsa_private_key or opts.client_secret + if not key then + return nil, "Can't use " .. auth .. " without a key." + end + body.client_id = opts.client_id + body.client_assertion_type = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + local now = ngx.time() + local assertion = { + header = { + typ = "JWT", + alg = auth == "private_key_jwt" and "RS256" or "HS256", + }, + payload = { + iss = opts.client_id, + sub = opts.client_id, + aud = endpoint, + jti = ngx.var.request_id, + exp = now + (opts.client_jwt_assertion_expires_in and opts.client_jwt_assertion_expires_in or 60), + iat = now + } + } + if auth == "private_key_jwt" then + assertion.header.kid = opts.client_rsa_private_key_id + end + + local r_jwt = require("resty.jwt") + body.client_assertion = r_jwt:sign(key, assertion) + log(DEBUG, auth .. ": client_id, client_assertion_type and client_assertion being sent in POST body") + end + end + + local pass_cookies = opts.pass_cookies + if pass_cookies then + if ngx.req.get_headers()["Cookie"] then + local t = {} + for cookie_name in string.gmatch(pass_cookies, "%S+") do + local cookie_value = ngx.var["cookie_" .. cookie_name] + if cookie_value then + table.insert(t, cookie_name .. "=" .. cookie_value) + end + end + headers.Cookie = table.concat(t, "; ") + end + end + + log(DEBUG, "request body for " .. ep_name .. " endpoint call: ", ngx.encode_args(body)) + + local httpc = http.new() + openidc_configure_timeouts(httpc, opts.timeout) + openidc_configure_proxy(httpc, opts.proxy_opts) + local res, err = httpc:request_uri(endpoint, decorate_request(opts.http_request_decorator, { + method = "POST", + body = ngx.encode_args(body), + headers = headers, + ssl_verify = (opts.ssl_verify ~= "no"), + keepalive = (opts.keepalive ~= "no") + })) + if not res then + err = "accessing " .. ep_name .. " endpoint (" .. endpoint .. ") failed: " .. err + log(ERROR, err) + return nil, err + end + + log(DEBUG, ep_name .. " endpoint response: ", res.body) + + return openidc_parse_json_response(res, ignore_body_on_success) +end + +-- computes access_token expires_in value (in seconds) +local function openidc_access_token_expires_in(opts, expires_in) + return (expires_in or opts.access_token_expires_in or 3600) - 1 - (opts.access_token_expires_leeway or 0) +end + +local function openidc_load_jwt_none_alg(enc_hdr, enc_payload) + local header = cjson_s.decode(openidc_base64_url_decode(enc_hdr)) + local payload = cjson_s.decode(openidc_base64_url_decode(enc_payload)) + if header and payload and header.alg == "none" then + return { + raw_header = enc_hdr, + raw_payload = enc_payload, + header = header, + payload = payload, + signature = '' + } + end + return nil +end + +-- get the Discovery metadata from the specified URL +local function openidc_discover(url, ssl_verify, keepalive, timeout, exptime, proxy_opts, http_request_decorator) + log(DEBUG, "openidc_discover: URL is: " .. url) + + local json, err + local v = openidc_cache_get("discovery", url) + if not v then + + log(DEBUG, "discovery data not in cache, making call to discovery endpoint") + -- make the call to the discovery endpoint + local httpc = http.new() + openidc_configure_timeouts(httpc, timeout) + openidc_configure_proxy(httpc, proxy_opts) + local res, error = httpc:request_uri(url, decorate_request(http_request_decorator, { + ssl_verify = (ssl_verify ~= "no"), + keepalive = (keepalive ~= "no") + })) + if not res then + err = "accessing discovery url (" .. url .. ") failed: " .. error + log(ERROR, err) + else + log(DEBUG, "response data: " .. res.body) + json, err = openidc_parse_json_response(res) + if json then + openidc_cache_set("discovery", url, cjson.encode(json), exptime or 24 * 60 * 60) + else + err = "could not decode JSON from Discovery data" .. (err and (": " .. err) or '') + log(ERROR, err) + end + end + + else + json = cjson.decode(v) + end + + return json, err +end + +-- turn a discovery url set in the opts dictionary into the discovered information +local function openidc_ensure_discovered_data(opts) + local err + if type(opts.discovery) == "string" then + local discovery + discovery, err = openidc_discover(opts.discovery, opts.ssl_verify, opts.keepalive, opts.timeout, opts.discovery_expires_in, opts.proxy_opts, + opts.http_request_decorator) + if not err then + opts.discovery = discovery + end + end + return err +end + +-- make a call to the userinfo endpoint +function openidc.call_userinfo_endpoint(opts, access_token) + local err = openidc_ensure_discovered_data(opts) + if err then + return nil, err + end + if not (opts and opts.discovery and opts.discovery.userinfo_endpoint) then + log(DEBUG, "no userinfo endpoint supplied") + return nil, nil + end + + local headers = { + ["Authorization"] = "Bearer " .. access_token, + } + + log(DEBUG, "authorization header '" .. headers.Authorization .. "'") + + local httpc = http.new() + openidc_configure_timeouts(httpc, opts.timeout) + openidc_configure_proxy(httpc, opts.proxy_opts) + local res, err = httpc:request_uri(opts.discovery.userinfo_endpoint, + decorate_request(opts.http_request_decorator, { + headers = headers, + ssl_verify = (opts.ssl_verify ~= "no"), + keepalive = (opts.keepalive ~= "no") + })) + if not res then + err = "accessing (" .. opts.discovery.userinfo_endpoint .. ") failed: " .. err + return nil, err + end + + log(DEBUG, "userinfo response: ", res.body) + + -- parse the response from the user info endpoint + return openidc_parse_json_response(res) +end + +local function can_use_token_auth_method(method, opts) + local supported = supported_token_auth_methods[method] + return supported and (type(supported) ~= 'function' or supported(opts)) +end + +-- get the token endpoint authentication method +local function openidc_get_token_auth_method(opts) + + if opts.token_endpoint_auth_method ~= nil and not can_use_token_auth_method(opts.token_endpoint_auth_method, opts) then + log(ERROR, "configured value for token_endpoint_auth_method (" .. opts.token_endpoint_auth_method .. ") is not supported, ignoring it") + opts.token_endpoint_auth_method = nil + end + + local result + if opts.discovery.token_endpoint_auth_methods_supported ~= nil then + -- if set check to make sure the discovery data includes the selected client auth method + if opts.token_endpoint_auth_method ~= nil then + for index, value in ipairs(opts.discovery.token_endpoint_auth_methods_supported) do + log(DEBUG, index .. " => " .. value) + if value == opts.token_endpoint_auth_method then + log(DEBUG, "configured value for token_endpoint_auth_method (" .. opts.token_endpoint_auth_method .. ") found in token_endpoint_auth_methods_supported in metadata") + result = opts.token_endpoint_auth_method + break + end + end + if result == nil then + log(ERROR, "configured value for token_endpoint_auth_method (" .. opts.token_endpoint_auth_method .. ") NOT found in token_endpoint_auth_methods_supported in metadata") + return nil + end + else + for index, value in ipairs(opts.discovery.token_endpoint_auth_methods_supported) do + log(DEBUG, index .. " => " .. value) + if can_use_token_auth_method(value, opts) then + result = value + log(DEBUG, "no configuration setting for option so select the first supported method specified by the OP: " .. result) + break + end + end + end + else + result = opts.token_endpoint_auth_method + end + + -- set a sane default if auto-configuration failed + if result == nil then + result = "client_secret_basic" + end + + log(DEBUG, "token_endpoint_auth_method result set to " .. result) + + return result +end + +-- ensure that discovery and token auth configuration is available in opts +local function ensure_config(opts) + local err + err = openidc_ensure_discovered_data(opts) + if err then + return err + end + + -- set the authentication method for the token endpoint + opts.token_endpoint_auth_method = openidc_get_token_auth_method(opts) +end + +-- query for discovery endpoint data +function openidc.get_discovery_doc(opts) + local err = openidc_ensure_discovered_data(opts) + if err then + log(ERROR, "error getting endpoints definition using discovery endpoint") + end + + return opts.discovery, err +end + +local function openidc_jwks(url, force, ssl_verify, keepalive, timeout, exptime, proxy_opts, http_request_decorator) + log(DEBUG, "openidc_jwks: URL is: " .. url .. " (force=" .. force .. ") (decorator=" .. (http_request_decorator and type(http_request_decorator) or "nil")) + + local json, err, v + + if force == 0 then + v = openidc_cache_get("jwks", url) + end + + if not v then + + log(DEBUG, "cannot use cached JWKS data; making call to jwks endpoint") + -- make the call to the jwks endpoint + local httpc = http.new() + openidc_configure_timeouts(httpc, timeout) + openidc_configure_proxy(httpc, proxy_opts) + local res, error = httpc:request_uri(url, decorate_request(http_request_decorator, { + ssl_verify = (ssl_verify ~= "no"), + keepalive = (keepalive ~= "no") + })) + if not res then + err = "accessing jwks url (" .. url .. ") failed: " .. error + log(ERROR, err) + else + log(DEBUG, "response data: " .. res.body) + json, err = openidc_parse_json_response(res) + if json then + openidc_cache_set("jwks", url, cjson.encode(json), exptime or 24 * 60 * 60) + end + end + + else + json = cjson.decode(v) + end + + return json, err +end + +local function split_by_chunk(text, chunkSize) + local s = {} + for i = 1, #text, chunkSize do + s[#s + 1] = text:sub(i, i + chunkSize - 1) + end + return s +end + +local function get_jwk(keys, kid) + + local rsa_keys = {} + for _, value in pairs(keys) do + if value.kty == "RSA" and (not value.use or value.use == "sig") then + table.insert(rsa_keys, value) + end + end + + if kid == nil then + if #rsa_keys == 1 then + log(DEBUG, "returning only RSA key of JWKS for keyid-less JWT") + return rsa_keys[1], nil + else + return nil, "JWT doesn't specify kid but the keystore contains multiple RSA keys" + end + end + for _, value in pairs(rsa_keys) do + if value.kid == kid then + return value, nil + end + end + + return nil, "RSA key with id " .. kid .. " not found" +end + +local wrap = ('.'):rep(64) + +local envelope = "-----BEGIN %s-----\n%s\n-----END %s-----\n" + +local function der2pem(data, typ) + typ = typ:upper() or "CERTIFICATE" + data = b64(data) + return string.format(envelope, typ, data:gsub(wrap, '%0\n', (#data - 1) / 64), typ) +end + + +local function encode_length(length) + if length < 0x80 then + return string.char(length) + elseif length < 0x100 then + return string.char(0x81, length) + elseif length < 0x10000 then + return string.char(0x82, math.floor(length / 0x100), length % 0x100) + end + error("Can't encode lengths over 65535") +end + + +local function encode_sequence(array, of) + local encoded_array = array + if of then + encoded_array = {} + for i = 1, #array do + encoded_array[i] = of(array[i]) + end + end + encoded_array = table.concat(encoded_array) + + return string.char(0x30) .. encode_length(#encoded_array) .. encoded_array +end + +local function encode_binary_integer(bytes) + if bytes:byte(1) > 127 then + -- We currenly only use this for unsigned integers, + -- however since the high bit is set here, it would look + -- like a negative signed int, so prefix with zeroes + bytes = "\0" .. bytes + end + return "\2" .. encode_length(#bytes) .. bytes +end + +local function encode_sequence_of_integer(array) + return encode_sequence(array, encode_binary_integer) +end + +local function encode_bit_string(array) + local s = "\0" .. array -- first octet holds the number of unused bits + return "\3" .. encode_length(#s) .. s +end + +local function openidc_pem_from_x5c(x5c) + log(DEBUG, "Found x5c, getting PEM public key from x5c entry of json public key") + local chunks = split_by_chunk(b64(openidc_base64_url_decode(x5c[1])), 64) + local pem = "-----BEGIN CERTIFICATE-----\n" .. + table.concat(chunks, "\n") .. + "\n-----END CERTIFICATE-----" + log(DEBUG, "Generated PEM key from x5c:", pem) + return pem +end + +local function openidc_pem_from_rsa_n_and_e(n, e) + log(DEBUG, "getting PEM public key from n and e parameters of json public key") + + local der_key = { + openidc_base64_url_decode(n), openidc_base64_url_decode(e) + } + local encoded_key = encode_sequence_of_integer(der_key) + local pem = der2pem(encode_sequence({ + encode_sequence({ + "\6\9\42\134\72\134\247\13\1\1\1" -- OID :rsaEncryption + .. "\5\0" -- ASN.1 NULL of length 0 + }), + encode_bit_string(encoded_key) + }), "PUBLIC KEY") + log(DEBUG, "Generated pem key from n and e: ", pem) + return pem +end + +local function openidc_pem_from_jwk(opts, kid) + local err = openidc_ensure_discovered_data(opts) + if err then + return nil, err + end + + if not opts.discovery.jwks_uri or not (type(opts.discovery.jwks_uri) == "string") or (opts.discovery.jwks_uri == "") then + return nil, "opts.discovery.jwks_uri is not present or not a string" + end + + local cache_id = opts.discovery.jwks_uri .. '#' .. (kid or '') + local v = openidc_cache_get("jwks", cache_id) + + if v then + return v + end + + local jwk, jwks + + for force = 0, 1 do + jwks, err = openidc_jwks(opts.discovery.jwks_uri, force, opts.ssl_verify, opts.keepalive, opts.timeout, opts.jwk_expires_in, opts.proxy_opts, + opts.http_request_decorator) + if err then + return nil, err + end + + jwk, err = get_jwk(jwks.keys, kid) + + if jwk and not err then + break + end + end + + if err then + return nil, err + end + + local x5c = jwk.x5c + if x5c and #(jwk.x5c) == 0 then + log(WARN, "Found invalid JWK with empty x5c array, ignoring x5c claim") + x5c = nil + end + + local pem + if x5c then + pem = openidc_pem_from_x5c(x5c) + elseif jwk.kty == "RSA" and jwk.n and jwk.e then + pem = openidc_pem_from_rsa_n_and_e(jwk.n, jwk.e) + else + return nil, "don't know how to create RSA key/cert for " .. cjson.encode(jwk) + end + + openidc_cache_set("jwks", cache_id, pem, opts.jwk_expires_in or 24 * 60 * 60) + return pem +end + +-- does lua-resty-jwt and/or we know how to handle the algorithm of the JWT? +local function is_algorithm_supported(jwt_header) + return jwt_header and jwt_header.alg and (jwt_header.alg == "none" + or string.sub(jwt_header.alg, 1, 2) == "RS" + or string.sub(jwt_header.alg, 1, 2) == "HS") +end + +-- is the JWT signing algorithm an asymmetric one whose key might be +-- obtained from the discovery endpoint? +local function uses_asymmetric_algorithm(jwt_header) + return string.sub(jwt_header.alg, 1, 2) == "RS" +end + +-- is the JWT signing algorithm one that has been expected? +local function is_algorithm_expected(jwt_header, expected_algs) + if expected_algs == nil or not jwt_header or not jwt_header.alg then + return true + end + if type(expected_algs) == 'string' then + expected_algs = { expected_algs } + end + for _, alg in ipairs(expected_algs) do + if alg == jwt_header.alg then + return true + end + end + return false +end + +-- parse a JWT and verify its signature (if present) +local function openidc_load_jwt_and_verify_crypto(opts, jwt_string, asymmetric_secret, +symmetric_secret, expected_algs, ...) + local r_jwt = require("resty.jwt") + local enc_hdr, enc_payload, enc_sign = string.match(jwt_string, '^(.+)%.(.+)%.(.*)$') + if enc_payload and (not enc_sign or enc_sign == "") then + local jwt = openidc_load_jwt_none_alg(enc_hdr, enc_payload) + if jwt then + if opts.accept_none_alg then + log(DEBUG, "accept JWT with alg \"none\" and no signature") + return jwt + else + return jwt, "token uses \"none\" alg but accept_none_alg is not enabled" + end + end -- otherwise the JWT is invalid and load_jwt produces an error + end + + local jwt_obj = r_jwt:load_jwt(jwt_string, nil) + if not jwt_obj.valid then + local reason = "invalid jwt" + if jwt_obj.reason then + reason = reason .. ": " .. jwt_obj.reason + end + return nil, reason + end + + if not is_algorithm_expected(jwt_obj.header, expected_algs) then + local alg = jwt_obj.header and jwt_obj.header.alg or "no algorithm at all" + return nil, "token is signed by unexpected algorithm \"" .. alg .. "\"" + end + + local secret + if is_algorithm_supported(jwt_obj.header) then + if uses_asymmetric_algorithm(jwt_obj.header) then + if opts.secret then + log(WARN, "using deprecated option `opts.secret` for asymmetric key; switch to `opts.public_key` instead") + end + secret = asymmetric_secret or opts.secret + if not secret and opts.discovery then + log(DEBUG, "using discovery to find key") + local err + secret, err = openidc_pem_from_jwk(opts, jwt_obj.header.kid) + + if secret == nil then + log(ERROR, err) + return nil, err + end + end + else + if opts.secret then + log(WARN, "using deprecated option `opts.secret` for symmetric key; switch to `opts.symmetric_key` instead") + end + secret = symmetric_secret or opts.secret + end + end + + if #{ ... } == 0 then + -- an empty list of claim specs makes lua-resty-jwt add default + -- validators for the exp and nbf claims if they are + -- present. These validators need to know the configured slack + -- value + local jwt_validators = require("resty.jwt-validators") + jwt_validators.set_system_leeway(opts.iat_slack and opts.iat_slack or 120) + end + + jwt_obj = r_jwt:verify_jwt_obj(secret, jwt_obj, ...) + if jwt_obj then + log(DEBUG, "jwt: ", cjson.encode(jwt_obj), " ,valid: ", jwt_obj.valid, ", verified: ", jwt_obj.verified) + end + if not jwt_obj.verified then + local reason = "jwt signature verification failed" + if jwt_obj.reason then + reason = reason .. ": " .. jwt_obj.reason + end + return jwt_obj, reason + end + return jwt_obj +end + +-- +-- Load and validate id token from the id_token properties of the token endpoint response +-- Parameters : +-- - opts the openidc module options +-- - jwt_id_token the id_token from the id_token properties of the token endpoint response +-- - session the current session +-- Return the id_token, nil if valid +-- Return nil, the error if invalid +-- +local function openidc_load_and_validate_jwt_id_token(opts, jwt_id_token, session) + + local jwt_obj, err = openidc_load_jwt_and_verify_crypto(opts, jwt_id_token, opts.public_key, opts.client_secret, + opts.discovery.id_token_signing_alg_values_supported) + if err then + local alg = (jwt_obj and jwt_obj.header and jwt_obj.header.alg) or '' + local is_unsupported_signature_error = jwt_obj and not jwt_obj.verified and not is_algorithm_supported(jwt_obj.header) + if is_unsupported_signature_error then + if opts.accept_unsupported_alg == nil or opts.accept_unsupported_alg then + log(WARN, "ignored id_token signature as algorithm '" .. alg .. "' is not supported") + else + err = "token is signed using algorithm \"" .. alg .. "\" which is not supported by lua-resty-jwt" + log(ERROR, err) + return nil, err + end + else + log(ERROR, "id_token '" .. alg .. "' signature verification failed") + return nil, err + end + end + local id_token = jwt_obj.payload + + log(DEBUG, "id_token header: ", cjson.encode(jwt_obj.header)) + log(DEBUG, "id_token payload: ", cjson.encode(jwt_obj.payload)) + + -- validate the id_token contents + if openidc_validate_id_token(opts, id_token, session.data.nonce) == false then + err = "id_token validation failed" + log(ERROR, err) + return nil, err + end + + return id_token +end + +-- handle a "code" authorization response from the OP +local function openidc_authorization_response(opts, session) + local args = ngx.req.get_uri_args() + local err, log_err, client_err + + if not args.code or not args.state then + err = "unhandled request to the redirect_uri: " .. ngx.var.request_uri + log(ERROR, err) + return nil, err, session.data.original_url, session + end + + -- check that the state returned in the response against the session; prevents CSRF + if args.state ~= session.data.state then + log_err = "state from argument: " .. (args.state and args.state or "nil") .. " does not match state restored from session: " .. (session.data.state and session.data.state or "nil") + client_err = "state from argument does not match state restored from session" + log(ERROR, log_err) + return nil, client_err, session.data.original_url, session + end + + err = ensure_config(opts) + if err then + return nil, err, session.data.original_url, session + end + + -- check the iss if returned from the OP + if args.iss and args.iss ~= opts.discovery.issuer then + log_err = "iss from argument: " .. args.iss .. " does not match expected issuer: " .. opts.discovery.issuer + client_err = "iss from argument does not match expected issuer" + log(ERROR, log_err) + return nil, client_err, session.data.original_url, session + end + + -- check the client_id if returned from the OP + if args.client_id and args.client_id ~= opts.client_id then + log_err = "client_id from argument: " .. args.client_id .. " does not match expected client_id: " .. opts.client_id + client_err = "client_id from argument does not match expected client_id" + log(ERROR, log_err) + return nil, client_err, session.data.original_url, session + end + + -- assemble the parameters to the token endpoint + local body = { + grant_type = "authorization_code", + code = args.code, + redirect_uri = openidc_get_redirect_uri(opts, session), + state = session.data.state, + code_verifier = session.data.code_verifier + } + + log(DEBUG, "Authentication with OP done -> Calling OP Token Endpoint to obtain tokens") + + local current_time = ngx.time() + -- make the call to the token endpoint + local json + json, err = openidc.call_token_endpoint(opts, opts.discovery.token_endpoint, body, opts.token_endpoint_auth_method) + if err then + return nil, err, session.data.original_url, session + end + + local id_token, err = openidc_load_and_validate_jwt_id_token(opts, json.id_token, session); + if err then + return nil, err, session.data.original_url, session + end + + -- mark this sessions as authenticated + session.data.authenticated = true + -- clear state, nonce and code_verifier to protect against potential misuse + session.data.nonce = nil + session.data.state = nil + session.data.code_verifier = nil + if store_in_session(opts, 'id_token') then + session.data.id_token = id_token + end + + if store_in_session(opts, 'user') then + -- call the user info endpoint + -- TODO: should this error be checked? + local user + user, err = openidc.call_userinfo_endpoint(opts, json.access_token) + + if err then + log(ERROR, "error calling userinfo endpoint: " .. err) + elseif user then + if id_token.sub ~= user.sub then + err = "\"sub\" claim in id_token (\"" .. (id_token.sub or "null") .. "\") is not equal to the \"sub\" claim returned from the userinfo endpoint (\"" .. (user.sub or "null") .. "\")" + log(ERROR, err) + else + session.data.user = user + end + end + end + + if store_in_session(opts, 'enc_id_token') then + session.data.enc_id_token = json.id_token + end + + if store_in_session(opts, 'access_token') then + session.data.access_token = json.access_token + session.data.access_token_expiration = current_time + + openidc_access_token_expires_in(opts, json.expires_in) + if json.refresh_token ~= nil then + session.data.refresh_token = json.refresh_token + end + end + + if opts.lifecycle and opts.lifecycle.on_authenticated then + err = opts.lifecycle.on_authenticated(session, id_token, json) + if err then + log(WARN, "failed in `on_authenticated` handler: " .. err) + return nil, err, session.data.original_url, session + end + end + + -- save the session with the obtained id_token + session:save() + + -- redirect to the URL that was accessed originally + log(DEBUG, "OIDC Authorization Code Flow completed -> Redirecting to original URL (" .. session.data.original_url .. ")") + ngx.redirect(session.data.original_url) + return nil, nil, session.data.original_url, session +end + +-- token revocation (RFC 7009) +local function openidc_revoke_token(opts, token_type_hint, token) + if not opts.discovery.revocation_endpoint then + log(DEBUG, "no revocation endpoint supplied. unable to revoke " .. token_type_hint .. ".") + return nil + end + + local token_type_hint = token_type_hint or nil + local body = { + token = token + } + if token_type_hint then + body['token_type_hint'] = token_type_hint + end + local token_type_log = token_type_hint or 'token' + + -- ensure revocation endpoint auth method is properly discovered + local err = ensure_config(opts) + if err then + log(ERROR, "revocation of " .. token_type_log .. " unsuccessful: " .. err) + return false + end + + -- call the revocation endpoint + local _ + _, err = openidc.call_token_endpoint(opts, opts.discovery.revocation_endpoint, body, opts.token_endpoint_auth_method, "revocation", true) + if err then + log(ERROR, "revocation of " .. token_type_log .. " unsuccessful: " .. err) + return false + else + log(DEBUG, "revocation of " .. token_type_log .. " successful") + return true + end +end + +function openidc.revoke_token(opts, token_type_hint, token) + local err = openidc_ensure_discovered_data(opts) + if err then + log(ERROR, "revocation of " .. (token_type_hint or "token (no type specified)") .. " unsuccessful: " .. err) + return false + end + + return openidc_revoke_token(opts, token_type_hint, token) +end + +function openidc.revoke_tokens(opts, session) + local err = openidc_ensure_discovered_data(opts) + if err then + log(ERROR, "revocation of tokens unsuccessful: " .. err) + return false + end + + local access_token = session.data.access_token + local refresh_token = session.data.refresh_token + + local access_token_revoke, refresh_token_revoke + if refresh_token then + access_token_revoke = openidc_revoke_token(opts, "refresh_token", refresh_token) + end + if access_token then + refresh_token_revoke = openidc_revoke_token(opts, "access_token", access_token) + end + return access_token_revoke and refresh_token_revoke +end + +local openidc_transparent_pixel = "\137\080\078\071\013\010\026\010\000\000\000\013\073\072\068\082" .. + "\000\000\000\001\000\000\000\001\008\004\000\000\000\181\028\012" .. + "\002\000\000\000\011\073\068\065\084\120\156\099\250\207\000\000" .. + "\002\007\001\002\154\028\049\113\000\000\000\000\073\069\078\068" .. + "\174\066\096\130" + +-- handle logout +local function openidc_logout(opts, session) + local session_token = session.data.enc_id_token + local access_token = session.data.access_token + local refresh_token = session.data.refresh_token + local err + + if opts.lifecycle and opts.lifecycle.on_logout then + err = opts.lifecycle.on_logout(session) + if err then + log(WARN, "failed in `on_logout` handler: " .. err) + return err + end + end + + session:destroy() + + if opts.revoke_tokens_on_logout then + log(DEBUG, "revoke_tokens_on_logout is enabled. " .. + "trying to revoke access and refresh tokens...") + if refresh_token then + openidc_revoke_token(opts, "refresh_token", refresh_token) + end + if access_token then + openidc_revoke_token(opts, "access_token", access_token) + end + end + + local headers = ngx.req.get_headers() + local header = get_first(headers['Accept']) + if header and header:find("image/png") then + ngx.header["Cache-Control"] = "no-cache, no-store" + ngx.header["Pragma"] = "no-cache" + ngx.header["P3P"] = "CAO PSA OUR" + ngx.header["Expires"] = "0" + ngx.header["X-Frame-Options"] = "DENY" + ngx.header.content_type = "image/png" + ngx.print(openidc_transparent_pixel) + ngx.exit(ngx.OK) + return + elseif opts.redirect_after_logout_uri or opts.discovery.end_session_endpoint then + local uri + if opts.redirect_after_logout_uri then + uri = opts.redirect_after_logout_uri + else + uri = opts.discovery.end_session_endpoint + end + local params = {} + if (opts.redirect_after_logout_with_id_token_hint or not opts.redirect_after_logout_uri) and session_token then + params["id_token_hint"] = session_token + end + if opts.post_logout_redirect_uri then + params["post_logout_redirect_uri"] = opts.post_logout_redirect_uri + end + return ngx.redirect(openidc_combine_uri(uri, params)) + elseif opts.discovery.ping_end_session_endpoint then + local params = {} + if opts.post_logout_redirect_uri then + params["TargetResource"] = opts.post_logout_redirect_uri + end + return ngx.redirect(openidc_combine_uri(opts.discovery.ping_end_session_endpoint, params)) + end + + ngx.header.content_type = "text/html" + ngx.say("<html><body>Logged Out</body></html>") + ngx.exit(ngx.OK) +end + +-- returns a valid access_token (eventually refreshing the token) +local function openidc_access_token(opts, session, try_to_renew) + + local err + + if session.data.access_token == nil then + return nil, err + end + local current_time = ngx.time() + if current_time < session.data.access_token_expiration then + return session.data.access_token, err + end + if not try_to_renew then + return nil, "token expired" + end + if session.data.refresh_token == nil then + return nil, "token expired and no refresh token available" + end + + log(DEBUG, "refreshing expired access_token: ", session.data.access_token, " with: ", session.data.refresh_token) + + -- retrieve token endpoint URL from discovery endpoint if necessary + err = ensure_config(opts) + if err then + return nil, err + end + + -- assemble the parameters to the token endpoint + local body = { + grant_type = "refresh_token", + refresh_token = session.data.refresh_token, + scope = opts.scope and opts.scope or "openid email profile" + } + + local json + json, err = openidc.call_token_endpoint(opts, opts.discovery.token_endpoint, body, opts.token_endpoint_auth_method) + if err then + return nil, err + end + local id_token + if json.id_token then + id_token, err = openidc_load_and_validate_jwt_id_token(opts, json.id_token, session) + if err then + log(ERROR, "invalid id token, discarding tokens returned while refreshing") + return nil, err + end + end + log(DEBUG, "access_token refreshed: ", json.access_token, " updated refresh_token: ", json.refresh_token) + + session.data.access_token = json.access_token + session.data.access_token_expiration = current_time + openidc_access_token_expires_in(opts, json.expires_in) + if json.refresh_token then + session.data.refresh_token = json.refresh_token + end + + if json.id_token and + (store_in_session(opts, 'enc_id_token') or store_in_session(opts, 'id_token')) then + log(DEBUG, "id_token refreshed: ", json.id_token) + if store_in_session(opts, 'enc_id_token') then + session.data.enc_id_token = json.id_token + end + if store_in_session(opts, 'id_token') then + session.data.id_token = id_token + end + end + + -- save the session with the new access_token and optionally the new refresh_token and id_token using a new sessionid + local regenerated + regenerated, err = session:regenerate() + if err then + log(ERROR, "failed to regenerate session: " .. err) + return nil, err + end + if opts.lifecycle and opts.lifecycle.on_regenerated then + err = opts.lifecycle.on_regenerated(session) + if err then + log(WARN, "failed in `on_regenerated` handler: " .. err) + return nil, err + end + end + + return session.data.access_token, err +end + +local function openidc_get_path(uri) + local without_query = uri:match("(.-)%?") or uri + return without_query:match(".-//[^/]+(/.*)") or without_query +end + +local function openidc_get_redirect_uri_path(opts) + return opts.redirect_uri and openidc_get_path(opts.redirect_uri) or opts.redirect_uri_path +end + +local function is_session(o) + return o ~= nil and o.start and type(o.start) == "function" +end + +-- main routine for OpenID Connect user authentication +function openidc.authenticate(opts, target_url, unauth_action, session_or_opts) + + if opts.redirect_uri_path then + log(WARN, "using deprecated option `opts.redirect_uri_path`; switch to using an absolute URI and `opts.redirect_uri` instead") + end + + local err + + local session + if is_session(session_or_opts) then + session = session_or_opts + else + local session_error + session, session_error = r_session.start(session_or_opts) + if session == nil then + log(ERROR, "Error starting session: " .. session_error) + return nil, session_error, target_url, session + end + end + + target_url = target_url or ngx.var.request_uri + + local access_token + + -- see if this is a request to the redirect_uri i.e. an authorization response + local path = openidc_get_path(target_url) + if path == openidc_get_redirect_uri_path(opts) then + log(DEBUG, "Redirect URI path (" .. path .. ") is currently navigated -> Processing authorization response coming from OP") + + if not session.present then + err = "request to the redirect_uri path but there's no session state found" + log(ERROR, err) + return nil, err, target_url, session + end + + return openidc_authorization_response(opts, session) + end + + -- see if this is a request to logout + if path == (opts.logout_path or "/logout") then + log(DEBUG, "Logout path (" .. path .. ") is currently navigated -> Processing local session removal before redirecting to next step of logout process") + + err = ensure_config(opts) + if err then + return nil, err, session.data.original_url, session + end + + openidc_logout(opts, session) + return nil, nil, target_url, session + end + + local token_expired = false + local try_to_renew = opts.renew_access_token_on_expiry == nil or opts.renew_access_token_on_expiry + if session.present and session.data.authenticated + and store_in_session(opts, 'access_token') then + + -- refresh access_token if necessary + access_token, err = openidc_access_token(opts, session, try_to_renew) + if err then + log(ERROR, "lost access token:" .. err) + err = nil + end + if not access_token then + token_expired = true + end + end + + log(DEBUG, + "session.present=", session.present, + ", session.data.id_token=", session.data.id_token ~= nil, + ", session.data.authenticated=", session.data.authenticated, + ", opts.force_reauthorize=", opts.force_reauthorize, + ", opts.renew_access_token_on_expiry=", opts.renew_access_token_on_expiry, + ", try_to_renew=", try_to_renew, + ", token_expired=", token_expired) + + -- if we are not authenticated then redirect to the OP for authentication + -- the presence of the id_token is check for backwards compatibility + if not session.present + or not (session.data.id_token or session.data.authenticated) + or opts.force_reauthorize + or (try_to_renew and token_expired) then + if unauth_action == "pass" then + if token_expired then + session.data.authenticated = false + return nil, 'token refresh failed', target_url, session + end + return nil, err, target_url, session + end + if unauth_action == 'deny' then + return nil, 'unauthorized request', target_url, session + end + + err = ensure_config(opts) + if err then + return nil, err, session.data.original_url, session + end + + log(DEBUG, "Authentication is required - Redirecting to OP Authorization endpoint") + openidc_authorize(opts, session, target_url, opts.prompt) + return nil, nil, target_url, session + end + + -- silently reauthenticate if necessary (mainly used for session refresh/getting updated id_token data) + if opts.refresh_session_interval ~= nil then + if session.data.last_authenticated == nil or (session.data.last_authenticated + opts.refresh_session_interval) < ngx.time() then + err = ensure_config(opts) + if err then + return nil, err, session.data.original_url, session + end + + log(DEBUG, "Silent authentication is required - Redirecting to OP Authorization endpoint") + openidc_authorize(opts, session, target_url, "none") + return nil, nil, target_url, session + end + end + + if store_in_session(opts, 'id_token') then + -- log id_token contents + log(DEBUG, "id_token=", cjson.encode(session.data.id_token)) + end + + -- return the id_token to the caller Lua script for access control purposes + return + { + id_token = session.data.id_token, + access_token = access_token, + user = session.data.user + }, + err, + target_url, + session +end + +-- get a valid access_token (eventually refreshing the token), or nil if there's no valid access_token +function openidc.access_token(opts, session_opts) + + local session = r_session.start(session_opts) + local token, err = openidc_access_token(opts, session, true) + session:close() + return token, err +end + + +-- get an OAuth 2.0 bearer access token from the HTTP request cookies +local function openidc_get_bearer_access_token_from_cookie(opts) + + local err + + log(DEBUG, "getting bearer access token from Cookie") + + local accept_token_as = opts.auth_accept_token_as or "header" + if accept_token_as:find("cookie") ~= 1 then + return nil, "openidc_get_bearer_access_token_from_cookie called but auth_accept_token_as wants " + .. opts.auth_accept_token_as + end + local divider = accept_token_as:find(':') + local cookie_name = divider and accept_token_as:sub(divider + 1) or "PA.global" + + log(DEBUG, "bearer access token from cookie named: " .. cookie_name) + + local cookies = ngx.req.get_headers()["Cookie"] + if not cookies then + err = "no Cookie header found" + log(ERROR, err) + return nil, err + end + + local cookie_value = ngx.var["cookie_" .. cookie_name] + if not cookie_value then + err = "no Cookie " .. cookie_name .. " found" + log(ERROR, err) + end + + return cookie_value, err +end + + +-- get an OAuth 2.0 bearer access token from the HTTP request +local function openidc_get_bearer_access_token(opts) + + local err + + local accept_token_as = opts.auth_accept_token_as or "header" + + if accept_token_as:find("cookie") == 1 then + return openidc_get_bearer_access_token_from_cookie(opts) + end + + -- get the access token from the Authorization header + local headers = ngx.req.get_headers() + local header_name = opts.auth_accept_token_as_header_name or "Authorization" + local header = get_first(headers[header_name]) + + if header == nil or header:find(" ") == nil then + err = "no Authorization header found" + log(ERROR, err) + return nil, err + end + + local divider = header:find(' ') + if string.lower(header:sub(0, divider - 1)) ~= string.lower("Bearer") then + err = "no Bearer authorization header value found" + log(ERROR, err) + return nil, err + end + + local access_token = header:sub(divider + 1) + if access_token == nil then + err = "no Bearer access token value found" + log(ERROR, err) + return nil, err + end + + return access_token, err +end + +local function get_introspection_endpoint(opts) + local introspection_endpoint = opts.introspection_endpoint + if not introspection_endpoint then + local err = openidc_ensure_discovered_data(opts) + if err then + return nil, "opts.introspection_endpoint not said and " .. err + end + local endpoint = opts.discovery and opts.discovery.introspection_endpoint + if endpoint then + return endpoint + end + end + return introspection_endpoint +end + +local function get_introspection_cache_prefix(opts) + return (opts.cache_segment and opts.cache_segment.gsub(',', '_') or 'DEFAULT') .. ',' + .. (get_introspection_endpoint(opts) or 'nil-endpoint') .. ',' + .. (opts.client_id or 'no-client_id') .. ',' + .. (opts.client_secret and 'secret' or 'no-client_secret') .. ':' +end + +local function get_cached_introspection(opts, access_token) + local introspection_cache_ignore = opts.introspection_cache_ignore or false + if not introspection_cache_ignore then + return openidc_cache_get("introspection", + get_introspection_cache_prefix(opts) .. access_token) + end +end + +local function set_cached_introspection(opts, access_token, encoded_json, ttl) + local introspection_cache_ignore = opts.introspection_cache_ignore or false + if not introspection_cache_ignore then + openidc_cache_set("introspection", + get_introspection_cache_prefix(opts) .. access_token, + encoded_json, ttl) + end +end + +-- main routine for OAuth 2.0 token introspection +function openidc.introspect(opts) + + -- get the access token from the request + local access_token, err = openidc_get_bearer_access_token(opts) + if access_token == nil then + return nil, err + end + + -- see if we've previously cached the introspection result for this access token + local json + local v = get_cached_introspection(opts, access_token) + + if v then + json = cjson.decode(v) + return json, err + end + + -- assemble the parameters to the introspection (token) endpoint + local token_param_name = opts.introspection_token_param_name and opts.introspection_token_param_name or "token" + + local body = {} + + body[token_param_name] = access_token + + if opts.client_id then + body.client_id = opts.client_id + end + if opts.client_secret then + body.client_secret = opts.client_secret + end + + -- merge any provided extra parameters + if opts.introspection_params then + for key, val in pairs(opts.introspection_params) do body[key] = val end + end + + -- call the introspection endpoint + local introspection_endpoint + introspection_endpoint, err = get_introspection_endpoint(opts) + if err then + return nil, err + end + json, err = openidc.call_token_endpoint(opts, introspection_endpoint, body, opts.introspection_endpoint_auth_method, "introspection") + + + if not json then + return json, err + end + + if not json.active then + err = "invalid token" + return json, err + end + + -- cache the results + local introspection_cache_ignore = opts.introspection_cache_ignore or false + local expiry_claim = opts.introspection_expiry_claim or "exp" + + if not introspection_cache_ignore and json[expiry_claim] then + local introspection_interval = opts.introspection_interval or 0 + local ttl = json[expiry_claim] + if expiry_claim == "exp" then --https://tools.ietf.org/html/rfc7662#section-2.2 + ttl = ttl - ngx.time() + end + if introspection_interval > 0 then + if ttl > introspection_interval then + ttl = introspection_interval + end + end + log(DEBUG, "cache token ttl: " .. ttl) + set_cached_introspection(opts, access_token, cjson.encode(json), ttl) + end + + return json, err + +end + +local function get_jwt_verification_cache_prefix(opts) + local signing_alg_values_expected = (opts.accept_none_alg and 'none' or 'no-none') + local expected_algs = opts.token_signing_alg_values_expected or {} + if type(expected_algs) == 'string' then + expected_algs = { expected_algs } + end + for _, alg in ipairs(expected_algs) do + signing_alg_values_expected = signing_alg_values_expected .. ',' .. alg + end + return (opts.cache_segment and opts.cache_segment.gsub(',', '_') or 'DEFAULT') .. ',' + .. (opts.public_key or 'no-pubkey') .. ',' + .. (opts.symmetric_key or 'no-symkey') .. ',' + .. signing_alg_values_expected .. ':' +end + +local function get_cached_jwt_verification(opts, access_token) + local jwt_verification_cache_ignore = opts.jwt_verification_cache_ignore or false + if not jwt_verification_cache_ignore then + return openidc_cache_get("jwt_verification", + get_jwt_verification_cache_prefix(opts) .. access_token) + end +end + +local function set_cached_jwt_verification(opts, access_token, encoded_json, ttl) + local jwt_verification_cache_ignore = opts.jwt_verification_cache_ignore or false + if not jwt_verification_cache_ignore then + openidc_cache_set("jwt_verification", + get_jwt_verification_cache_prefix(opts) .. access_token, + encoded_json, ttl) + end +end + +-- main routine for OAuth 2.0 JWT token validation +-- optional args are claim specs, see jwt-validators in resty.jwt +function openidc.jwt_verify(access_token, opts, ...) + local err + local json + local v = get_cached_jwt_verification(opts, access_token) + + local slack = opts.iat_slack and opts.iat_slack or 120 + if not v then + local jwt_obj + jwt_obj, err = openidc_load_jwt_and_verify_crypto(opts, access_token, opts.public_key, opts.symmetric_key, + opts.token_signing_alg_values_expected, ...) + if not err then + json = jwt_obj.payload + local encoded_json = cjson.encode(json) + log(DEBUG, "jwt: ", encoded_json) + + set_cached_jwt_verification(opts, access_token, encoded_json, + json.exp and json.exp - ngx.time() or 120) + end + + else + -- decode from the cache + json = cjson.decode(v) + end + + -- check the token expiry + if json then + if json.exp and json.exp + slack < ngx.time() then + log(ERROR, "token expired: json.exp=", json.exp, ", ngx.time()=", ngx.time()) + err = "JWT expired" + end + end + + return json, err +end + +function openidc.bearer_jwt_verify(opts, ...) + local json + + -- get the access token from the request + local access_token, err = openidc_get_bearer_access_token(opts) + if access_token == nil then + return nil, err + end + + log(DEBUG, "access_token: ", access_token) + + json, err = openidc.jwt_verify(access_token, opts, ...) + return json, err, access_token +end + +-- Passing nil to any of the arguments resets the configuration to default +function openidc.set_logging(new_log, new_levels) + log = new_log and new_log or ngx.log + DEBUG = new_levels.DEBUG and new_levels.DEBUG or ngx.DEBUG + ERROR = new_levels.ERROR and new_levels.ERROR or ngx.ERR + WARN = new_levels.WARN and new_levels.WARN or ngx.WARN +end + +return openidc diff --git a/server/resty/openssl.lua b/server/resty/openssl.lua new file mode 100644 index 0000000..27ef5cc --- /dev/null +++ b/server/resty/openssl.lua @@ -0,0 +1,476 @@ +local ffi = require("ffi") +local C = ffi.C +local ffi_cast = ffi.cast +local ffi_str = ffi.string + +local format_error = require("resty.openssl.err").format_error + +local OPENSSL_3X, BORINGSSL + +local function try_require_modules() + package.loaded["resty.openssl.version"] = nil + + local pok, lib = pcall(require, "resty.openssl.version") + if pok then + OPENSSL_3X = lib.OPENSSL_3X + BORINGSSL = lib.BORINGSSL + + require "resty.openssl.include.crypto" + require "resty.openssl.include.objects" + else + package.loaded["resty.openssl.version"] = nil + end +end +try_require_modules() + + +local _M = { + _VERSION = '0.8.16', +} + +local libcrypto_name +local lib_patterns = { + "%s", "%s.so.3", "%s.so.1.1", "%s.so.1.0" +} + +function _M.load_library() + for _, pattern in ipairs(lib_patterns) do + -- true: load to global namespae + local pok, _ = pcall(ffi.load, string.format(pattern, "crypto"), true) + if pok then + libcrypto_name = string.format(pattern, "crypto") + ffi.load(string.format(pattern, "ssl"), true) + + try_require_modules() + + return libcrypto_name + end + end + + return false, "unable to load crypto library" +end + +function _M.load_modules() + _M.bn = require("resty.openssl.bn") + _M.cipher = require("resty.openssl.cipher") + _M.digest = require("resty.openssl.digest") + _M.hmac = require("resty.openssl.hmac") + _M.kdf = require("resty.openssl.kdf") + _M.pkey = require("resty.openssl.pkey") + _M.objects = require("resty.openssl.objects") + _M.rand = require("resty.openssl.rand") + _M.version = require("resty.openssl.version") + _M.x509 = require("resty.openssl.x509") + _M.altname = require("resty.openssl.x509.altname") + _M.chain = require("resty.openssl.x509.chain") + _M.csr = require("resty.openssl.x509.csr") + _M.crl = require("resty.openssl.x509.crl") + _M.extension = require("resty.openssl.x509.extension") + _M.extensions = require("resty.openssl.x509.extensions") + _M.name = require("resty.openssl.x509.name") + _M.revoked = require("resty.openssl.x509.revoked") + _M.store = require("resty.openssl.x509.store") + _M.pkcs12 = require("resty.openssl.pkcs12") + _M.ssl = require("resty.openssl.ssl") + _M.ssl_ctx = require("resty.openssl.ssl_ctx") + + if OPENSSL_3X then + _M.provider = require("resty.openssl.provider") + _M.mac = require("resty.openssl.mac") + _M.ctx = require("resty.openssl.ctx") + end + + _M.bignum = _M.bn +end + +function _M.luaossl_compat() + _M.load_modules() + + _M.csr.setSubject = _M.csr.set_subject_name + _M.csr.setPublicKey = _M.csr.set_pubkey + + _M.x509.setPublicKey = _M.x509.set_pubkey + _M.x509.getPublicKey = _M.x509.get_pubkey + _M.x509.setSerial = _M.x509.set_serial_number + _M.x509.getSerial = _M.x509.get_serial_number + _M.x509.setSubject = _M.x509.set_subject_name + _M.x509.getSubject = _M.x509.get_subject_name + _M.x509.setIssuer = _M.x509.set_issuer_name + _M.x509.getIssuer = _M.x509.get_issuer_name + _M.x509.getOCSP = _M.x509.get_ocsp_url + + local pkey_new = _M.pkey.new + _M.pkey.new = function(a, b) + if type(a) == "string" then + return pkey_new(a, b and unpack(b)) + else + return pkey_new(a, b) + end + end + + _M.cipher.encrypt = function(self, key, iv, padding) + return self, _M.cipher.init(self, key, iv, true, not padding) + end + _M.cipher.decrypt = function(self, key, iv, padding) + return self, _M.cipher.init(self, key, iv, false, not padding) + end + + local digest_update = _M.digest.update + _M.digest.update = function(self, ...) + local ok, err = digest_update(self, ...) + if ok then + return self + else + return nil, err + end + end + + local store_verify = _M.store.verify + _M.store.verify = function(...) + local ok, err = store_verify(...) + if err then + return false, err + else + return true, ok + end + end + + local kdf_derive = _M.kdf.derive + local kdf_keys_mappings = { + iter = "pbkdf2_iter", + key = "hkdf_key", + info = "hkdf_info", + secret = "tls1_prf_secret", + seed = "tls1_prf_seed", + maxmem_bytes = "scrypt_maxmem", + N = "scrypt_N", + r = "scrypt_r", + p = "scrypt_p", + } + _M.kdf.derive = function(o) + for k1, k2 in pairs(kdf_keys_mappings) do + o[k1] = o[k2] + o[k2] = nil + end + local hkdf_mode = o.hkdf_mode + if hkdf_mode == "extract_and_expand" then + o.hkdf_mode = _M.kdf.HKDEF_MODE_EXTRACT_AND_EXPAND + elseif hkdf_mode == "extract_only" then + o.hkdf_mode = _M.kdf.HKDEF_MODE_EXTRACT_ONLY + elseif hkdf_mode == "expand_only" then + o.hkdf_mode = _M.kdf.HKDEF_MODE_EXPAND_ONLY + end + return kdf_derive(o) + end + + _M.pkcs12.new = function(tbl) + local certs = {} + local passphrase = tbl.passphrase + if not tbl.key then + return nil, "key must be set" + end + for _, cert in ipairs(tbl.certs) do + if not _M.x509.istype(cert) then + return nil, "certs must contains only x509 instance" + end + if cert:check_private_key(tbl.key) then + tbl.cert = cert + else + certs[#certs+1] = cert + end + end + tbl.cacerts = certs + return _M.pkcs12.encode(tbl, passphrase) + end + + _M.crl.add = _M.crl.add_revoked + _M.crl.lookupSerial = _M.crl.get_by_serial + + for mod, tbl in pairs(_M) do + if type(tbl) == 'table' then + + -- avoid using a same table as the iterrator will change + local new_tbl = {} + -- luaossl always error() out + for k, f in pairs(tbl) do + if type(f) == 'function' then + local of = f + new_tbl[k] = function(...) + local ret = { of(...) } + if ret and #ret > 1 and ret[#ret] then + error(mod .. "." .. k .. "(): " .. ret[#ret]) + end + return unpack(ret) + end + end + end + + for k, f in pairs(new_tbl) do + tbl[k] = f + end + + setmetatable(tbl, { + __index = function(t, k) + local tok + -- handle special case + if k == 'toPEM' then + tok = 'to_PEM' + else + tok = k:gsub("(%l)(%u)", function(a, b) return a .. "_" .. b:lower() end) + if tok == k then + return + end + end + if type(tbl[tok]) == 'function' then + return tbl[tok] + end + end + }) + end + end + + -- skip error() conversion + _M.pkcs12.parse = function(p12, passphrase) + local r, err = _M.pkcs12.decode(p12, passphrase) + if err then error(err) end + return r.key, r.cert, r.cacerts + end +end + +if OPENSSL_3X then + require "resty.openssl.include.evp" + local provider = require "resty.openssl.provider" + local ctx_lib = require "resty.openssl.ctx" + local fips_provider_ctx + + function _M.set_fips_mode(enable, self_test) + if (not not enable) == _M.get_fips_mode() then + return true + end + + if enable then + local p, err = provider.load("fips") + if not p then + return false, err + end + fips_provider_ctx = p + if self_test then + local ok, err = p:self_test() + if not ok then + return false, err + end + end + + elseif fips_provider_ctx then -- disable + local p = fips_provider_ctx + fips_provider_ctx = nil + return p:unload() + end + + -- set algorithm in fips mode in default ctx + -- this deny/allow non-FIPS compliant algorithms to be used from EVP interface + -- and redirect/remove redirect implementation to fips provider + if C.EVP_default_properties_enable_fips(ctx_lib.get_libctx(), enable and 1 or 0) == 0 then + return false, format_error("openssl.set_fips_mode: EVP_default_properties_enable_fips") + end + + return true + end + + function _M.get_fips_mode() + local pok = provider.is_available("fips") + if not pok then + return false + end + + return C.EVP_default_properties_is_fips_enabled(ctx_lib.get_libctx()) == 1 + end + +else + function _M.set_fips_mode(enable) + if (not not enable) == _M.get_fips_mode() then + return true + end + + if C.FIPS_mode_set(enable and 1 or 0) == 0 then + return false, format_error("openssl.set_fips_mode") + end + + return true + end + + function _M.get_fips_mode() + return C.FIPS_mode() == 1 + end +end + +function _M.set_default_properties(props) + if not OPENSSL_3X then + return nil, "openssl.set_default_properties is only not supported from OpenSSL 3.0" + end + + local ctx_lib = require "resty.openssl.ctx" + + if C.EVP_set_default_properties(ctx_lib.get_libctx(), props) == 0 then + return false, format_error("openssl.EVP_set_default_properties") + end + + return true +end + +local function list_legacy(typ, get_nid_cf) + local typ_lower = string.lower(typ:sub(5)) -- cut off EVP_ + require ("resty.openssl.include.evp." .. typ_lower) + + local ret = {} + local fn = ffi_cast("fake_openssl_" .. typ_lower .. "_list_fn*", + function(elem, from, to, arg) + if elem ~= nil then + local nid = get_nid_cf(elem) + table.insert(ret, ffi_str(C.OBJ_nid2sn(nid))) + end + -- from/to (renamings) are ignored + end) + C[typ .. "_do_all_sorted"](fn, nil) + fn:free() + + return ret +end + +local function list_provided(typ) + local typ_lower = string.lower(typ:sub(5)) -- cut off EVP_ + local typ_ptr = typ .. "*" + require ("resty.openssl.include.evp." .. typ_lower) + local ctx_lib = require "resty.openssl.ctx" + + local ret = {} + + local fn = ffi_cast("fake_openssl_" .. typ_lower .. "_provided_list_fn*", + function(elem, _) + elem = ffi_cast(typ_ptr, elem) + local name = ffi_str(C[typ .. "_get0_name"](elem)) + -- alternate names are ignored, retrieve use TYPE_names_do_all + local prov = ffi_str(C.OSSL_PROVIDER_get0_name(C[typ .. "_get0_provider"](elem))) + table.insert(ret, name .. " @ " .. prov) + end) + + C[typ .. "_do_all_provided"](ctx_lib.get_libctx(), fn, nil) + fn:free() + + table.sort(ret) + return ret +end + +function _M.list_cipher_algorithms() + if BORINGSSL then + return nil, "openssl.list_cipher_algorithms is not supported on BoringSSL" + end + + require "resty.openssl.include.evp.cipher" + local ret = list_legacy("EVP_CIPHER", + OPENSSL_3X and C.EVP_CIPHER_get_nid or C.EVP_CIPHER_nid) + + if OPENSSL_3X then + local ret_provided = list_provided("EVP_CIPHER") + for _, r in ipairs(ret_provided) do + table.insert(ret, r) + end + end + + return ret +end + +function _M.list_digest_algorithms() + if BORINGSSL then + return nil, "openssl.list_digest_algorithms is not supported on BoringSSL" + end + + require "resty.openssl.include.evp.md" + local ret = list_legacy("EVP_MD", + OPENSSL_3X and C.EVP_MD_get_type or C.EVP_MD_type) + + if OPENSSL_3X then + local ret_provided = list_provided("EVP_MD") + for _, r in ipairs(ret_provided) do + table.insert(ret, r) + end + end + + return ret +end + +function _M.list_mac_algorithms() + if not OPENSSL_3X then + return nil, "openssl.list_mac_algorithms is only supported from OpenSSL 3.0" + end + + return list_provided("EVP_MAC") +end + +function _M.list_kdf_algorithms() + if not OPENSSL_3X then + return nil, "openssl.list_kdf_algorithms is only supported from OpenSSL 3.0" + end + + return list_provided("EVP_KDF") +end + +local valid_ssl_protocols = { + ["SSLv3"] = 0x0300, + ["TLSv1"] = 0x0301, + ["TLSv1.1"] = 0x0302, + ["TLSv1.2"] = 0x0303, + ["TLSv1.3"] = 0x0304, +} + +function _M.list_ssl_ciphers(cipher_list, ciphersuites, protocol) + local ssl_lib = require("resty.openssl.ssl") + local ssl_macro = require("resty.openssl.include.ssl") + + if protocol then + if not valid_ssl_protocols[protocol] then + return nil, "unknown protocol \"" .. protocol .. "\"" + end + protocol = valid_ssl_protocols[protocol] + end + + local ssl_ctx = C.SSL_CTX_new(C.TLS_server_method()) + if ssl_ctx == nil then + return nil, format_error("SSL_CTX_new") + end + ffi.gc(ssl_ctx, C.SSL_CTX_free) + + local ssl = C.SSL_new(ssl_ctx) + if ssl == nil then + return nil, format_error("SSL_new") + end + ffi.gc(ssl, C.SSL_free) + + if protocol then + if ssl_macro.SSL_set_min_proto_version(ssl, protocol) == 0 or + ssl_macro.SSL_set_max_proto_version(ssl, protocol) == 0 then + return nil, format_error("SSL_set_min/max_proto_version") + end + end + + ssl = { ctx = ssl } + + local ok, err + if cipher_list then + ok, err = ssl_lib.set_cipher_list(ssl, cipher_list) + if not ok then + return nil, err + end + end + + if ciphersuites then + ok, err = ssl_lib.set_ciphersuites(ssl, ciphersuites) + if not ok then + return nil, err + end + end + + return ssl_lib.get_ciphers(ssl) +end + +return _M diff --git a/server/resty/openssl/asn1.lua b/server/resty/openssl/asn1.lua new file mode 100644 index 0000000..0fa0605 --- /dev/null +++ b/server/resty/openssl/asn1.lua @@ -0,0 +1,91 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string +local floor = math.floor + +local asn1_macro = require("resty.openssl.include.asn1") + +-- https://github.com/wahern/luaossl/blob/master/src/openssl.c +local function isleap(year) + return (year % 4) == 0 and ((year % 100) > 0 or (year % 400) == 0) +end + +local past = { 0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334 } +local function yday(year, mon, mday) + local d = past[mon] + mday - 1 + if mon > 2 and isleap(year) then + d = d + 1 + end + return d +end + +local function leaps(year) + return floor(year / 400) + floor(year / 4) - floor(year / 100) +end + +local function asn1_to_unix(asn1) + if asn1 == nil then + return nil, "except an ASN1 instance at #1, got nil" + end + + local s = asn1_macro.ASN1_STRING_get0_data(asn1) + s = ffi_str(s) + -- V_ASN1_UTCTIME 190303223958Z + -- V_ASN1_GENERALIZEDTIME 21190822162753Z + local yyoffset = 2 + local year + -- # define V_ASN1_GENERALIZEDTIME 24 + if C.ASN1_STRING_type(asn1) == 24 then + yyoffset = 4 + year = tonumber(s:sub(1, yyoffset)) + else + year = tonumber(s:sub(1, yyoffset)) + year = year + (year < 50 and 2000 or 1900) + end + local month = tonumber(s:sub(yyoffset+1, yyoffset+2)) + if month > 12 or month < 1 then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + local day = tonumber(s:sub(yyoffset+3, yyoffset+4)) + if day > 31 or day < 1 then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + local hour = tonumber(s:sub(yyoffset+5, yyoffset+6)) + if hour > 23 or hour < 0 then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + local minute = tonumber(s:sub(yyoffset+7, yyoffset+8)) + if minute > 59 or hour < 0 then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + local second = tonumber(s:sub(yyoffset+9, yyoffset+10)) + if second > 59 or second < 0 then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + + local tm + tm = (year - 1970) * 365 + tm = tm + leaps(year - 1) - leaps(1969) + tm = (tm + yday(year, month, day)) * 24 + tm = (tm + hour) * 60 + tm = (tm + minute) * 60 + tm = tm + second + + -- offset? + local sign = s:sub(yyoffset+11, yyoffset+11) + if sign == "+" or sign == "-" then + local sgn = sign == "+" and 1 or -1 + local hh = tonumber(s:sub(yyoffset+12, yyoffset+13) or 'no') + local mm = tonumber(s:sub(yyoffset+14, yyoffset+15) or 'no') + if not hh or not mm then + return nil, "asn1.asn1_to_unix: bad format " .. s + end + tm = tm + sgn * (hh * 3600 + mm * 60) + end + + return tm +end + +return { + asn1_to_unix = asn1_to_unix, +} diff --git a/server/resty/openssl/auxiliary/bio.lua b/server/resty/openssl/auxiliary/bio.lua new file mode 100644 index 0000000..3eed9f0 --- /dev/null +++ b/server/resty/openssl/auxiliary/bio.lua @@ -0,0 +1,43 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_new = ffi.new +local ffi_str = ffi.string + +require "resty.openssl.include.bio" +local format_error = require("resty.openssl.err").format_error + +local function read_wrap(f, ...) + if type(f) ~= "cdata" then -- should be explictly a function + return nil, "bio_util.read_wrap: expect a function at #1" + end + + local bio_method = C.BIO_s_mem() + if bio_method == nil then + return nil, "bio_util.read_wrap: BIO_s_mem() failed" + end + local bio = C.BIO_new(bio_method) + ffi_gc(bio, C.BIO_free) + + -- BIO_reset; #define BIO_CTRL_RESET 1 + local code = C.BIO_ctrl(bio, 1, 0, nil) + if code ~= 1 then + return nil, "bio_util.read_wrap: BIO_ctrl() failed: " .. code + end + + local code = f(bio, ...) + if code ~= 1 then + return nil, format_error(f, code) + end + + local buf = ffi_new("char *[1]") + + -- BIO_get_mem_data; #define BIO_CTRL_INFO 3 + local length = C.BIO_ctrl(bio, 3, 0, buf) + + return ffi_str(buf[0], length) +end + +return { + read_wrap = read_wrap, +}
\ No newline at end of file diff --git a/server/resty/openssl/auxiliary/ctypes.lua b/server/resty/openssl/auxiliary/ctypes.lua new file mode 100644 index 0000000..933822b --- /dev/null +++ b/server/resty/openssl/auxiliary/ctypes.lua @@ -0,0 +1,28 @@ +-- Put common type definition at the same place for convenience +-- and standarlization +local ffi = require "ffi" + +--[[ + TYPE_ptr: usually used to define a pointer (to cast or something) + char* var_name; // <- we use char_ptr + + ptr_of_TYPE: usually used to pass the pointer of an object that + is already allocated. so that we can also set value of it as well + + int p = 2; // ptr_of_int(); ptr_of_int[0] = 2; + plus_one(&p); // <- we use ptr_of_int +]] + +return { + void_ptr = ffi.typeof("void *"), + ptr_of_uint64 = ffi.typeof("uint64_t[1]"), + ptr_of_uint = ffi.typeof("unsigned int[1]"), + ptr_of_size_t = ffi.typeof("size_t[1]"), + ptr_of_int = ffi.typeof("int[1]"), + null = ffi.new("void *"), -- hack wher ngx.null is not available + + uchar_array = ffi.typeof("unsigned char[?]"), + uchar_ptr = ffi.typeof("unsigned char*"), + + SIZE_MAX = math.pow(2, 64), -- nginx set _FILE_OFFSET_BITS to 64 +}
\ No newline at end of file diff --git a/server/resty/openssl/auxiliary/jwk.lua b/server/resty/openssl/auxiliary/jwk.lua new file mode 100644 index 0000000..5a505a9 --- /dev/null +++ b/server/resty/openssl/auxiliary/jwk.lua @@ -0,0 +1,261 @@ + +local ffi = require "ffi" +local C = ffi.C + +local cjson = require("cjson.safe") +local b64 = require("ngx.base64") + +local evp_macro = require "resty.openssl.include.evp" +local rsa_lib = require "resty.openssl.rsa" +local ec_lib = require "resty.openssl.ec" +local ecx_lib = require "resty.openssl.ecx" +local bn_lib = require "resty.openssl.bn" +local digest_lib = require "resty.openssl.digest" + +local _M = {} + +local rsa_jwk_params = {"n", "e", "d", "p", "q", "dp", "dq", "qi"} +local rsa_openssl_params = rsa_lib.params + +local function load_jwk_rsa(tbl) + if not tbl["n"] or not tbl["e"] then + return nil, "at least \"n\" and \"e\" parameter is required" + end + + local params = {} + local err + for i, k in ipairs(rsa_jwk_params) do + local v = tbl[k] + if v then + v = b64.decode_base64url(v) + if not v then + return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. tbl[k] + end + + params[rsa_openssl_params[i]], err = bn_lib.from_binary(v) + if err then + return nil, "cannot use parameter \"" .. k .. "\": " .. err + end + end + end + + local key = C.RSA_new() + if key == nil then + return nil, "RSA_new() failed" + end + + local _, err = rsa_lib.set_parameters(key, params) + if err ~= nil then + C.RSA_free(key) + return nil, err + end + + return key +end + +local ec_curves = { + ["P-256"] = C.OBJ_ln2nid("prime256v1"), + ["P-384"] = C.OBJ_ln2nid("secp384r1"), + ["P-521"] = C.OBJ_ln2nid("secp521r1"), +} + +local ec_curves_reverse = {} +for k, v in pairs(ec_curves) do + ec_curves_reverse[v] = k +end + +local ec_jwk_params = {"x", "y", "d"} + +local function load_jwk_ec(tbl) + local curve = tbl['crv'] + if not curve then + return nil, "\"crv\" not defined for EC key" + end + if not tbl["x"] or not tbl["y"] then + return nil, "at least \"x\" and \"y\" parameter is required" + end + local curve_nid = ec_curves[curve] + if not curve_nid then + return nil, "curve \"" .. curve .. "\" is not supported by this library" + elseif curve_nid == 0 then + return nil, "curve \"" .. curve .. "\" is not supported by linked OpenSSL" + end + + local params = {} + local err + for _, k in ipairs(ec_jwk_params) do + local v = tbl[k] + if v then + v = b64.decode_base64url(v) + if not v then + return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. tbl[k] + end + + params[k], err = bn_lib.from_binary(v) + if err then + return nil, "cannot use parameter \"" .. k .. "\": " .. err + end + end + end + + -- map to the name we expect + if params["d"] then + params["private"] = params["d"] + params["d"] = nil + end + params["group"] = curve_nid + + local key = C.EC_KEY_new() + if key == nil then + return nil, "EC_KEY_new() failed" + end + + local _, err = ec_lib.set_parameters(key, params) + if err ~= nil then + C.EC_KEY_free(key) + return nil, err + end + + return key +end + +local function load_jwk_okp(key_type, tbl) + local params = {} + if tbl["d"] then + params.private = b64.decode_base64url(tbl["d"]) + elseif tbl["x"] then + params.public = b64.decode_base64url(tbl["x"]) + else + return nil, "at least \"x\" or \"d\" parameter is required" + end + local key, err = ecx_lib.set_parameters(key_type, nil, params) + if err ~= nil then + return nil, err + end + return key +end + +local ecx_curves_reverse = {} +for k, v in pairs(evp_macro.ecx_curves) do + ecx_curves_reverse[v] = k +end + +function _M.load_jwk(txt) + local tbl, err = cjson.decode(txt) + if err then + return nil, "error decoding JSON from JWK: " .. err + elseif type(tbl) ~= "table" then + return nil, "except input to be decoded as a table, got " .. type(tbl) + end + + local key, key_free, key_type, err + + if tbl["kty"] == "RSA" then + key_type = evp_macro.EVP_PKEY_RSA + if key_type == 0 then + return nil, "the linked OpenSSL library doesn't support RSA key" + end + key, err = load_jwk_rsa(tbl) + key_free = C.RSA_free + elseif tbl["kty"] == "EC" then + key_type = evp_macro.EVP_PKEY_EC + if key_type == 0 then + return nil, "the linked OpenSSL library doesn't support EC key" + end + key, err = load_jwk_ec(tbl) + key_free = C.EC_KEY_free + elseif tbl["kty"] == "OKP" then + local curve = tbl["crv"] + key_type = evp_macro.ecx_curves[curve] + if not key_type then + return nil, "unknown curve \"" .. tostring(curve) + elseif key_type == 0 then + return nil, "the linked OpenSSL library doesn't support \"" .. curve .. "\" key" + end + key, err = load_jwk_okp(key_type, tbl) + if key ~= nil then + return key + end + else + return nil, "not yet supported jwk type \"" .. (tbl["kty"] or "nil") .. "\"" + end + + if err then + return nil, "failed to construct " .. tbl["kty"] .. " key from JWK: " .. err + end + + local ctx = C.EVP_PKEY_new() + if ctx == nil then + key_free(key) + return nil, "EVP_PKEY_new() failed" + end + + local code = C.EVP_PKEY_assign(ctx, key_type, key) + if code ~= 1 then + key_free(key) + C.EVP_PKEY_free(ctx) + return nil, "EVP_PKEY_assign() failed" + end + + return ctx +end + +function _M.dump_jwk(pkey, is_priv) + local jwk + if pkey.key_type == evp_macro.EVP_PKEY_RSA then + local param_keys = { "n" , "e" } + if is_priv then + param_keys = rsa_jwk_params + end + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "RSA", + } + for i, p in ipairs(param_keys) do + local v = params[rsa_openssl_params[i]]:to_binary() + jwk[p] = b64.encode_base64url(v) + end + elseif pkey.key_type == evp_macro.EVP_PKEY_EC then + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "EC", + crv = ec_curves_reverse[params.group], + x = b64.encode_base64url(params.x:to_binary()), + y = b64.encode_base64url(params.x:to_binary()), + } + if is_priv then + jwk.d = b64.encode_base64url(params.private:to_binary()) + end + elseif ecx_curves_reverse[pkey.key_type] then + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "OKP", + crv = ecx_curves_reverse[pkey.key_type], + d = b64.encode_base64url(params.private), + x = b64.encode_base64url(params.public), + } + else + return nil, "jwk.dump_jwk: not implemented for this key type" + end + + local der = pkey:tostring(is_priv and "private" or "public", "DER") + local dgst = digest_lib.new("sha256") + local d, err = dgst:final(der) + if err then + return nil, "jwk.dump_jwk: failed to calculate digest for key" + end + jwk.kid = b64.encode_base64url(d) + + return cjson.encode(jwk) +end + +return _M diff --git a/server/resty/openssl/auxiliary/nginx.lua b/server/resty/openssl/auxiliary/nginx.lua new file mode 100644 index 0000000..8adeceb --- /dev/null +++ b/server/resty/openssl/auxiliary/nginx.lua @@ -0,0 +1,318 @@ +local get_req_ssl, get_req_ssl_ctx +local get_socket_ssl, get_socket_ssl_ctx + +local pok, nginx_c = pcall(require, "resty.openssl.auxiliary.nginx_c") + +if pok and not os.getenv("CI_SKIP_NGINX_C") then + get_req_ssl = nginx_c.get_req_ssl + get_req_ssl_ctx = nginx_c.get_req_ssl_ctx + get_socket_ssl = nginx_c.get_socket_ssl + get_socket_ssl_ctx = nginx_c.get_socket_ssl +else + local ffi = require "ffi" + + ffi.cdef [[ + // Nginx seems to always config _FILE_OFFSET_BITS=64, this should always be 8 byte + typedef long long off_t; + typedef unsigned int socklen_t; // windows uses int, same size + typedef unsigned short in_port_t; + + typedef struct ssl_st SSL; + typedef struct ssl_ctx_st SSL_CTX; + + typedef long (*ngx_recv_pt)(void *c, void *buf, size_t size); + typedef long (*ngx_recv_chain_pt)(void *c, void *in, + off_t limit); + typedef long (*ngx_send_pt)(void *c, void *buf, size_t size); + typedef void *(*ngx_send_chain_pt)(void *c, void *in, + off_t limit); + + typedef struct { + size_t len; + void *data; + } ngx_str_t; + + typedef struct { + SSL *connection; + SSL_CTX *session_ctx; + // trimmed + } ngx_ssl_connection_s; + ]] + + local ngx_version = ngx.config.nginx_version + if ngx_version == 1017008 or ngx_version == 1019003 or ngx_version == 1019009 + or ngx_version == 1021004 then + -- 1.17.8, 1.19.3, 1.19.9, 1.21.4 + -- https://github.com/nginx/nginx/blob/master/src/core/ngx_connection.h + ffi.cdef [[ + typedef struct { + ngx_str_t src_addr; + ngx_str_t dst_addr; + in_port_t src_port; + in_port_t dst_port; + } ngx_proxy_protocol_t; + + typedef struct { + void *data; + void *read; + void *write; + + int fd; + + ngx_recv_pt recv; + ngx_send_pt send; + ngx_recv_chain_pt recv_chain; + ngx_send_chain_pt send_chain; + + void *listening; + + off_t sent; + + void *log; + + void *pool; + + int type; + + void *sockaddr; + socklen_t socklen; + ngx_str_t addr_text; + + // https://github.com/nginx/nginx/commit/be932e81a1531a3ba032febad968fc2006c4fa48 + ngx_proxy_protocol_t *proxy_protocol; + + ngx_ssl_connection_s *ssl; + // trimmed + } ngx_connection_s; + ]] + else + error("resty.openssl.auxiliary.nginx doesn't support Nginx version " .. ngx_version, 2) + end + + ffi.cdef [[ + typedef struct { + ngx_connection_s *connection; + // trimmed + } ngx_stream_lua_request_s; + + typedef struct { + unsigned int signature; /* "HTTP" */ + + ngx_connection_s *connection; + // trimmed + } ngx_http_request_s; + ]] + + local get_request + do + local ok, exdata = pcall(require, "thread.exdata") + if ok and exdata then + function get_request() + local r = exdata() + if r ~= nil then + return r + end + end + + else + local getfenv = getfenv + + function get_request() + return getfenv(0).__ngx_req + end + end + end + + local SOCKET_CTX_INDEX = 1 + + local NO_C_MODULE_WARNING_MSG_SHOWN = false + local NO_C_MODULE_WARNING_MSG = "note resty.openssl.auxiliary.nginx is using plain FFI " .. + "and it's only intended to be used in development, " .. + "consider using lua-resty-openssl.aux-module in production." + + local function get_ngx_ssl_from_req() + if not NO_C_MODULE_WARNING_MSG_SHOWN then + ngx.log(ngx.WARN, NO_C_MODULE_WARNING_MSG) + NO_C_MODULE_WARNING_MSG_SHOWN = true + end + + local c = get_request() + if ngx.config.subsystem == "stream" then + c = ffi.cast("ngx_stream_lua_request_s*", c) + else -- http + c = ffi.cast("ngx_http_request_s*", c) + end + + local ngx_ssl = c.connection.ssl + if ngx_ssl == nil then + return nil, "c.connection.ssl is nil" + end + return ngx_ssl + end + + get_req_ssl = function() + local ssl, err = get_ngx_ssl_from_req() + if err then + return nil, err + end + + return ssl.connection + end + + get_req_ssl_ctx = function() + local ssl, err = get_ngx_ssl_from_req() + if err then + return nil, err + end + + return ssl.session_ctx + end + + ffi.cdef[[ + typedef struct ngx_http_lua_socket_tcp_upstream_s + ngx_http_lua_socket_tcp_upstream_t; + + typedef struct { + ngx_connection_s *connection; + // trimmed + } ngx_peer_connection_s; + + typedef + int (*ngx_http_lua_socket_tcp_retval_handler_masked)(void *r, + void *u, void *L); + + typedef void (*ngx_http_lua_socket_tcp_upstream_handler_pt_masked) + (void *r, void *u); + + + typedef + int (*ngx_stream_lua_socket_tcp_retval_handler)(void *r, + void *u, void *L); + + typedef void (*ngx_stream_lua_socket_tcp_upstream_handler_pt) + (void *r, void *u); + + typedef struct { + ngx_stream_lua_socket_tcp_retval_handler read_prepare_retvals; + ngx_stream_lua_socket_tcp_retval_handler write_prepare_retvals; + ngx_stream_lua_socket_tcp_upstream_handler_pt read_event_handler; + ngx_stream_lua_socket_tcp_upstream_handler_pt write_event_handler; + + void *socket_pool; + + void *conf; + void *cleanup; + void *request; + + ngx_peer_connection_s peer; + // trimmed + } ngx_stream_lua_socket_tcp_upstream_s; + ]] + + local ngx_lua_version = ngx.config and + ngx.config.ngx_lua_version and + ngx.config.ngx_lua_version + + if ngx_lua_version >= 10019 and ngx_lua_version <= 10021 then + -- https://github.com/openresty/lua-nginx-module/blob/master/src/ngx_http_lua_socket_tcp.h + ffi.cdef[[ + typedef struct { + ngx_http_lua_socket_tcp_retval_handler_masked read_prepare_retvals; + ngx_http_lua_socket_tcp_retval_handler_masked write_prepare_retvals; + ngx_http_lua_socket_tcp_upstream_handler_pt_masked read_event_handler; + ngx_http_lua_socket_tcp_upstream_handler_pt_masked write_event_handler; + + void *udata_queue; // 0.10.19 + + void *socket_pool; + + void *conf; + void *cleanup; + void *request; + ngx_peer_connection_s peer; + // trimmed + } ngx_http_lua_socket_tcp_upstream_s; + ]] + elseif ngx_lua_version < 10019 then + -- the struct doesn't seem to get changed a long time since birth + ffi.cdef[[ + typedef struct { + ngx_http_lua_socket_tcp_retval_handler_masked read_prepare_retvals; + ngx_http_lua_socket_tcp_retval_handler_masked write_prepare_retvals; + ngx_http_lua_socket_tcp_upstream_handler_pt_masked read_event_handler; + ngx_http_lua_socket_tcp_upstream_handler_pt_masked write_event_handler; + + void *socket_pool; + + void *conf; + void *cleanup; + void *request; + ngx_peer_connection_s peer; + // trimmed + } ngx_http_lua_socket_tcp_upstream_s; + ]] + else + error("resty.openssl.auxiliary.nginx doesn't support lua-nginx-module version " .. (ngx_lua_version or "nil"), 2) + end + + local function get_ngx_ssl_from_socket_ctx(sock) + if not NO_C_MODULE_WARNING_MSG_SHOWN then + ngx.log(ngx.WARN, NO_C_MODULE_WARNING_MSG) + NO_C_MODULE_WARNING_MSG_SHOWN = true + end + + local u = sock[SOCKET_CTX_INDEX] + if u == nil then + return nil, "lua_socket_tcp_upstream_t not found" + end + + if ngx.config.subsystem == "stream" then + u = ffi.cast("ngx_stream_lua_socket_tcp_upstream_s*", u) + else -- http + u = ffi.cast("ngx_http_lua_socket_tcp_upstream_s*", u) + end + + local p = u.peer + if p == nil then + return nil, "u.peer is nil" + end + + local uc = p.connection + if uc == nil then + return nil, "u.peer.connection is nil" + end + + local ngx_ssl = uc.ssl + if ngx_ssl == nil then + return nil, "u.peer.connection.ssl is nil" + end + return ngx_ssl + end + + get_socket_ssl = function(sock) + local ssl, err = get_ngx_ssl_from_socket_ctx(sock) + if err then + return nil, err + end + + return ssl.connection + end + + get_socket_ssl_ctx = function(sock) + local ssl, err = get_ngx_ssl_from_socket_ctx(sock) + if err then + return nil, err + end + + return ssl.session_ctx + end + +end + + +return { + get_req_ssl = get_req_ssl, + get_req_ssl_ctx = get_req_ssl_ctx, + get_socket_ssl = get_socket_ssl, + get_socket_ssl_ctx = get_socket_ssl_ctx, +} diff --git a/server/resty/openssl/auxiliary/nginx_c.lua b/server/resty/openssl/auxiliary/nginx_c.lua new file mode 100644 index 0000000..f50db36 --- /dev/null +++ b/server/resty/openssl/auxiliary/nginx_c.lua @@ -0,0 +1,154 @@ +local ffi = require "ffi" +local C = ffi.C + +local SOCKET_CTX_INDEX = 1 +local NGX_OK = ngx.OK + + +local get_req_ssl, get_req_ssl_ctx +local get_socket_ssl, get_socket_ssl_ctx + +local get_request +do + local ok, exdata = pcall(require, "thread.exdata") + if ok and exdata then + function get_request() + local r = exdata() + if r ~= nil then + return r + end + end + + else + local getfenv = getfenv + + function get_request() + return getfenv(0).__ngx_req + end + end +end + + +local stream_subsystem = false +if ngx.config.subsystem == "stream" then + stream_subsystem = true + + ffi.cdef [[ + typedef struct ngx_stream_lua_request_s ngx_stream_lua_request_t; + typedef struct ngx_stream_lua_socket_tcp_upstream_s ngx_stream_lua_socket_tcp_upstream_t; + + int ngx_stream_lua_resty_openssl_aux_get_request_ssl(ngx_stream_lua_request_t *r, + void **_ssl_conn); + + int ngx_stream_lua_resty_openssl_aux_get_request_ssl_ctx(ngx_stream_lua_request_t *r, + void **_sess); + + int ngx_stream_lua_resty_openssl_aux_get_socket_ssl(ngx_stream_lua_socket_tcp_upstream_t *u, + void **_ssl_conn); + + int ngx_stream_lua_resty_openssl_aux_get_socket_ssl_ctx(ngx_stream_lua_socket_tcp_upstream_t *u, + void **_sess); + ]] + + -- sanity test + local _ = C.ngx_stream_lua_resty_openssl_aux_get_request_ssl +else + ffi.cdef [[ + typedef struct ngx_http_request_s ngx_http_request_t; + typedef struct ngx_http_lua_socket_tcp_upstream_s ngx_http_lua_socket_tcp_upstream_t; + + int ngx_http_lua_resty_openssl_aux_get_request_ssl(ngx_http_request_t *r, + void **_ssl_conn); + + int ngx_http_lua_resty_openssl_aux_get_request_ssl_ctx(ngx_http_request_t *r, + void **_sess); + + int ngx_http_lua_resty_openssl_aux_get_socket_ssl(ngx_http_lua_socket_tcp_upstream_t *u, + void **_ssl_conn); + + int ngx_http_lua_resty_openssl_aux_get_socket_ssl_ctx(ngx_http_lua_socket_tcp_upstream_t *u, + void **_sess); + ]] + + -- sanity test + local _ = C.ngx_http_lua_resty_openssl_aux_get_request_ssl +end + +local void_pp = ffi.new("void *[1]") +local ssl_type = ffi.typeof("SSL*") +local ssl_ctx_type = ffi.typeof("SSL_CTX*") + +get_req_ssl = function() + local c = get_request() + + local ret + if stream_subsystem then + ret = C.ngx_stream_lua_resty_openssl_aux_get_request_ssl(c, void_pp) + else + ret = C.ngx_http_lua_resty_openssl_aux_get_request_ssl(c, void_pp) + end + + if ret ~= NGX_OK then + return nil, "cannot read r->connection->ssl->connection" + end + + return ffi.cast(ssl_type, void_pp[0]) +end + +get_req_ssl_ctx = function() + local c = get_request() + + local ret + if stream_subsystem then + ret = C.ngx_stream_lua_resty_openssl_aux_get_request_ssl_ctx(c, void_pp) + else + ret = C.ngx_http_lua_resty_openssl_aux_get_request_ssl_ctx(c, void_pp) + end + + if ret ~= NGX_OK then + return nil, "cannot read r->connection->ssl->session_ctx" + end + + return ffi.cast(ssl_ctx_type, void_pp[0]) +end + +get_socket_ssl = function(sock) + local u = sock[SOCKET_CTX_INDEX] + + local ret + if stream_subsystem then + ret = C.ngx_stream_lua_resty_openssl_aux_get_socket_ssl(u, void_pp) + else + ret = C.ngx_http_lua_resty_openssl_aux_get_socket_ssl(u, void_pp) + end + + if ret ~= NGX_OK then + return nil, "cannot read u->peer.connection->ssl->connection" + end + + return ffi.cast(ssl_type, void_pp[0]) +end + +get_socket_ssl_ctx = function(sock) + local u = sock[SOCKET_CTX_INDEX] + + local ret + if stream_subsystem then + ret = C.ngx_stream_lua_resty_openssl_aux_get_socket_ssl_ctx(u, void_pp) + else + ret = C.ngx_http_lua_resty_openssl_aux_get_socket_ssl_ctx(u, void_pp) + end + + if ret ~= NGX_OK then + return nil, "cannot read u->peer.connection->ssl->session_ctx" + end + + return ffi.cast(ssl_ctx_type, void_pp[0]) +end + +return { + get_req_ssl = get_req_ssl, + get_req_ssl_ctx = get_req_ssl_ctx, + get_socket_ssl = get_socket_ssl, + get_socket_ssl_ctx = get_socket_ssl_ctx, +}
\ No newline at end of file diff --git a/server/resty/openssl/bn.lua b/server/resty/openssl/bn.lua new file mode 100644 index 0000000..e893e5e --- /dev/null +++ b/server/resty/openssl/bn.lua @@ -0,0 +1,416 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_new = ffi.new +local ffi_str = ffi.string +local floor = math.floor + +require "resty.openssl.include.bn" +local crypto_macro = require("resty.openssl.include.crypto") +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local _M = {} +local mt = {__index = _M} + +local bn_ptr_ct = ffi.typeof('BIGNUM*') +local bn_ptrptr_ct = ffi.typeof('BIGNUM*[1]') + +function _M.new(bn) + local ctx = C.BN_new() + ffi_gc(ctx, C.BN_free) + + if type(bn) == 'number' then + if C.BN_set_word(ctx, bn) ~= 1 then + return nil, format_error("bn.new") + end + elseif bn then + return nil, "bn.new: expect nil or a number at #1" + end + + return setmetatable( { ctx = ctx }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(bn_ptr_ct, l.ctx) +end + +function _M.dup(ctx) + if not ffi.istype(bn_ptr_ct, ctx) then + return nil, "bn.dup: expect a bn ctx at #1" + end + local ctx = C.BN_dup(ctx) + ffi_gc(ctx, C.BN_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self +end + +function _M:to_binary() + local length = (C.BN_num_bits(self.ctx)+7)/8 + -- align to bytes + length = floor(length) + local buf = ctypes.uchar_array(length) + local sz = C.BN_bn2bin(self.ctx, buf) + if sz == 0 then + return nil, format_error("bn:to_binary") + end + buf = ffi_str(buf, length) + return buf +end + +function _M.from_binary(s) + if type(s) ~= "string" then + return nil, "bn.from_binary: expect a string at #1" + end + + local ctx = C.BN_bin2bn(s, #s, nil) + if ctx == nil then + return nil, format_error("bn.from_binary") + end + ffi_gc(ctx, C.BN_free) + return setmetatable( { ctx = ctx }, mt), nil +end + +function _M:to_hex() + local buf = C.BN_bn2hex(self.ctx) + if buf == nil then + return nil, format_error("bn:to_hex") + end + ffi_gc(buf, crypto_macro.OPENSSL_free) + local s = ffi_str(buf) + return s +end + +function _M.from_hex(s) + if type(s) ~= "string" then + return nil, "bn.from_hex: expect a string at #1" + end + + local p = ffi_new(bn_ptrptr_ct) + + if C.BN_hex2bn(p, s) == 0 then + return nil, format_error("bn.from_hex") + end + local ctx = p[0] + ffi_gc(ctx, C.BN_free) + return setmetatable( { ctx = ctx }, mt), nil +end + +function _M:to_dec() + local buf = C.BN_bn2dec(self.ctx) + if buf == nil then + return nil, format_error("bn:to_dec") + end + ffi_gc(buf, crypto_macro.OPENSSL_free) + local s = ffi_str(buf) + return s +end +mt.__tostring = _M.to_dec + +function _M.from_dec(s) + if type(s) ~= "string" then + return nil, "bn.from_dec: expect a string at #1" + end + + local p = ffi_new(bn_ptrptr_ct) + + if C.BN_dec2bn(p, s) == 0 then + return nil, format_error("bn.from_dec") + end + local ctx = p[0] + ffi_gc(ctx, C.BN_free) + return setmetatable( { ctx = ctx }, mt), nil +end + +function _M:to_number() + return tonumber(C.BN_get_word(self.ctx)) +end +_M.tonumber = _M.to_number + +function _M.generate_prime(bits, safe) + local ctx = C.BN_new() + ffi_gc(ctx, C.BN_free) + + if C.BN_generate_prime_ex(ctx, bits, safe and 1 or 0, nil, nil, nil) == 0 then + return nil, format_error("bn.BN_generate_prime_ex") + end + + return setmetatable( { ctx = ctx }, mt), nil +end + +-- BN_CTX is used to store temporary variable +-- we only need one per worker +local bn_ctx_tmp = C.BN_CTX_new() +assert(bn_ctx_tmp ~= nil) +if OPENSSL_10 then + C.BN_CTX_init(bn_ctx_tmp) +end +ffi_gc(bn_ctx_tmp, C.BN_CTX_free) + +_M.bn_ctx_tmp = bn_ctx_tmp + +-- mathematics + +local is_negative +if OPENSSL_10 then + local bn_zero = assert(_M.new(0)).ctx + is_negative = function(ctx) + return C.BN_cmp(ctx, bn_zero) < 0 and 1 or 0 + end +else + is_negative = C.BN_is_negative +end +function mt.__unm(a) + local b = _M.dup(a.ctx) + if b == nil then + error("BN_dup() failed") + end + local sign = is_negative(b.ctx) + C.BN_set_negative(b.ctx, 1-sign) + return b +end + +local function check_args(op, ...) + local args = {...} + for i, arg in ipairs(args) do + if type(arg) == 'number' then + local b = C.BN_new() + if b == nil then + error("BN_new() failed") + end + ffi_gc(b, C.BN_free) + if C.BN_set_word(b, arg) ~= 1 then + error("BN_set_word() failed") + end + args[i] = b + elseif _M.istype(arg) then + args[i] = arg.ctx + else + error("cannot " .. op .. " a " .. type(arg) .. " to bignum") + end + end + local ctx = C.BN_new() + if ctx == nil then + error("BN_new() failed") + end + ffi_gc(ctx, C.BN_free) + local r = setmetatable( { ctx = ctx }, mt) + return r, unpack(args) +end + + +function mt.__add(...) + local r, a, b = check_args("add", ...) + if C.BN_add(r.ctx, a, b) == 0 then + error("BN_add() failed") + end + return r +end +_M.add = mt.__add + +function mt.__sub(...) + local r, a, b = check_args("substract", ...) + if C.BN_sub(r.ctx, a, b) == 0 then + error("BN_sub() failed") + end + return r +end +_M.sub = mt.__sub + +function mt.__mul(...) + local r, a, b = check_args("multiply", ...) + if C.BN_mul(r.ctx, a, b, bn_ctx_tmp) == 0 then + error("BN_mul() failed") + end + return r +end +_M.mul = mt.__mul + +-- lua 5.3 only +function mt.__idiv(...) + local r, a, b = check_args("divide", ...) + if C.BN_div(r.ctx, nil, a, b, bn_ctx_tmp) == 0 then + error("BN_div() failed") + end + return r +end + +mt.__div = mt.__idiv +_M.idiv = mt.__idiv +_M.div = mt.__div + +function mt.__mod(...) + local r, a, b = check_args("mod", ...) + if C.BN_div(nil, r.ctx, a, b, bn_ctx_tmp) == 0 then + error("BN_div() failed") + end + return r +end +_M.mod = mt.__mod + +-- __concat doesn't make sense at all? + +function _M.sqr(...) + local r, a = check_args("square", ...) + if C.BN_sqr(r.ctx, a, bn_ctx_tmp) == 0 then + error("BN_sqr() failed") + end + return r +end + +function _M.gcd(...) + local r, a, b = check_args("extract greatest common divisor", ...) + if C.BN_gcd(r.ctx, a, b, bn_ctx_tmp) == 0 then + error("BN_gcd() failed") + end + return r +end + +function _M.exp(...) + local r, a, b = check_args("power", ...) + if C.BN_exp(r.ctx, a, b, bn_ctx_tmp) == 0 then + error("BN_exp() failed") + end + return r +end +_M.pow = _M.exp + +for _, op in ipairs({ "add", "sub" , "mul", "exp" }) do + local f = "BN_mod_" .. op + local cf = C[f] + _M["mod_" .. op] = function(...) + local r, a, b, m = check_args(op, ...) + if cf(r.ctx, a, b, m, bn_ctx_tmp) == 0 then + error(f .. " failed") + end + return r + end +end + +function _M.mod_sqr(...) + local r, a, m = check_args("mod_sub", ...) + if C.BN_mod_sqr(r.ctx, a, m, bn_ctx_tmp) == 0 then + error("BN_mod_sqr() failed") + end + return r +end + +local function nyi() + error("NYI") +end + +-- bit operations, lua 5.3 + +mt.__band = nyi +mt.__bor = nyi +mt.__bxor = nyi +mt.__bnot = nyi + +function mt.__shl(a, b) + local r, a = check_args("lshift", a) + if C.BN_lshift(r.ctx, a, b) == 0 then + error("BN_lshift() failed") + end + return r +end +_M.lshift = mt.__shl + +function mt.__shr(a, b) + local r, a = check_args("rshift", a) + if C.BN_rshift(r.ctx, a, b) == 0 then + error("BN_lshift() failed") + end + return r +end +_M.rshift = mt.__shr + +-- comparaions +-- those functions are only called when the table +-- has exact same metamethods, i.e. they are all BN +-- so we don't need to check args + +function mt.__eq(a, b) + return C.BN_cmp(a.ctx, b.ctx) == 0 +end + +function mt.__lt(a, b) + return C.BN_cmp(a.ctx, b.ctx) < 0 +end + +function mt.__le(a, b) + return C.BN_cmp(a.ctx, b.ctx) <= 0 +end + +if OPENSSL_10 then + -- in openssl 1.0.x those functions are implemented as macros + -- don't want to copy paste all structs here + -- the followings are definitely slower, but works + local bn_zero = assert(_M.new(0)).ctx + local bn_one = assert(_M.new(1)).ctx + + function _M:is_zero() + return C.BN_cmp(self.ctx, bn_zero) == 0 + end + + function _M:is_one() + return C.BN_cmp(self.ctx, bn_one) == 0 + end + + function _M:is_word(n) + local ctx = C.BN_new() + ffi_gc(ctx, C.BN_free) + if ctx == nil then + return nil, "bn:is_word: BN_new() failed" + end + if C.BN_set_word(ctx, n) ~= 1 then + return nil, "bn:is_word: BN_set_word() failed" + end + return C.BN_cmp(self.ctx, ctx) == 0 + end + + function _M:is_odd() + return self:to_number() % 2 == 1 + end +else + function _M:is_zero() + return C.BN_is_zero(self.ctx) == 1 + end + + function _M:is_one() + return C.BN_is_one(self.ctx) == 1 + end + + function _M:is_word(n) + return C.BN_is_word(self.ctx, n) == 1 + end + + function _M:is_odd() + return C.BN_is_odd(self.ctx) == 1 + end +end + +function _M:is_prime(nchecks) + if nchecks and type(nchecks) ~= "number" then + return nil, "bn:is_prime: expect a number at #1" + end + -- if nchecks is not defined, set to BN_prime_checks: + -- select number of iterations based on the size of the number + local code + if OPENSSL_3X then + code = C.BN_check_prime(self.ctx, bn_ctx_tmp, nil) + else + code = C.BN_is_prime_ex(self.ctx, nchecks or 0, bn_ctx_tmp, nil) + end + if code == -1 then + return nil, format_error("bn.is_prime") + end + return code == 1 +end + +return _M diff --git a/server/resty/openssl/cipher.lua b/server/resty/openssl/cipher.lua new file mode 100644 index 0000000..693ac09 --- /dev/null +++ b/server/resty/openssl/cipher.lua @@ -0,0 +1,300 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string +local ffi_cast = ffi.cast + +require "resty.openssl.include.evp.cipher" +local evp_macro = require "resty.openssl.include.evp" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local ctx_lib = require "resty.openssl.ctx" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local uchar_array = ctypes.uchar_array +local void_ptr = ctypes.void_ptr +local ptr_of_int = ctypes.ptr_of_int +local uchar_ptr = ctypes.uchar_ptr + +local _M = {} +local mt = {__index = _M} + +local cipher_ctx_ptr_ct = ffi.typeof('EVP_CIPHER_CTX*') + +local out_length = ptr_of_int() +-- EVP_MAX_BLOCK_LENGTH is 32, we give it a 64 to be future proof +local out_buffer = ctypes.uchar_array(1024 + 64) + +function _M.new(typ, properties) + if not typ then + return nil, "cipher.new: expect type to be defined" + end + + local ctx + if OPENSSL_11_OR_LATER then + ctx = C.EVP_CIPHER_CTX_new() + ffi_gc(ctx, C.EVP_CIPHER_CTX_free) + elseif OPENSSL_10 then + ctx = ffi.new('EVP_CIPHER_CTX') + C.EVP_CIPHER_CTX_init(ctx) + ffi_gc(ctx, C.EVP_CIPHER_CTX_cleanup) + end + if ctx == nil then + return nil, "cipher.new: failed to create EVP_CIPHER_CTX" + end + + local ctyp + if OPENSSL_3X then + ctyp = C.EVP_CIPHER_fetch(ctx_lib.get_libctx(), typ, properties) + else + ctyp = C.EVP_get_cipherbyname(typ) + end + local err_new = string.format("cipher.new: invalid cipher type \"%s\"", typ) + if ctyp == nil then + return nil, format_error(err_new) + end + + local code = C.EVP_CipherInit_ex(ctx, ctyp, nil, "", nil, -1) + if code ~= 1 then + return nil, format_error(err_new) + end + + return setmetatable({ + ctx = ctx, + algo = ctyp, + initialized = false, + block_size = tonumber(OPENSSL_3X and C.EVP_CIPHER_CTX_get_block_size(ctx) + or C.EVP_CIPHER_CTX_block_size(ctx)), + key_size = tonumber(OPENSSL_3X and C.EVP_CIPHER_CTX_get_key_length(ctx) + or C.EVP_CIPHER_CTX_key_length(ctx)), + iv_size = tonumber(OPENSSL_3X and C.EVP_CIPHER_CTX_get_iv_length(ctx) + or C.EVP_CIPHER_CTX_iv_length(ctx)), + }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(cipher_ctx_ptr_ct, l.ctx) +end + +function _M:get_provider_name() + if not OPENSSL_3X then + return false, "cipher:get_provider_name is not supported" + end + local p = C.EVP_CIPHER_get0_provider(self.algo) + if p == nil then + return nil + end + return ffi_str(C.OSSL_PROVIDER_get0_name(p)) +end + +if OPENSSL_3X then + local param_lib = require "resty.openssl.param" + _M.settable_params, _M.set_params, _M.gettable_params, _M.get_param = param_lib.get_params_func("EVP_CIPHER_CTX") +end + +function _M:init(key, iv, opts) + opts = opts or {} + if not key or #key ~= self.key_size then + return false, string.format("cipher:init: incorrect key size, expect %d", self.key_size) + end + if not iv or #iv ~= self.iv_size then + return false, string.format("cipher:init: incorrect iv size, expect %d", self.iv_size) + end + + -- always passed in the `EVP_CIPHER` parameter to reinitialized the cipher + -- it will have a same effect as EVP_CIPHER_CTX_cleanup/EVP_CIPHER_CTX_reset then Init_ex with + -- empty algo + if C.EVP_CipherInit_ex(self.ctx, self.algo, nil, key, iv, opts.is_encrypt and 1 or 0) == 0 then + return false, format_error("cipher:init EVP_CipherInit_ex") + end + + if opts.no_padding then + -- EVP_CIPHER_CTX_set_padding() always returns 1. + C.EVP_CIPHER_CTX_set_padding(self.ctx, 0) + end + + self.initialized = true + + return true +end + +function _M:encrypt(key, iv, s, no_padding, aead_aad) + local _, err = self:init(key, iv, { + is_encrypt = true, + no_padding = no_padding, + }) + if err then + return nil, err + end + if aead_aad then + local _, err = self:update_aead_aad(aead_aad) + if err then + return nil, err + end + end + return self:final(s) +end + +function _M:decrypt(key, iv, s, no_padding, aead_aad, aead_tag) + local _, err = self:init(key, iv, { + is_encrypt = false, + no_padding = no_padding, + }) + if err then + return nil, err + end + if aead_aad then + local _, err = self:update_aead_aad(aead_aad) + if err then + return nil, err + end + end + if aead_tag then + local _, err = self:set_aead_tag(aead_tag) + if err then + return nil, err + end + end + return self:final(s) +end + +-- https://wiki.openssl.org/index.php/EVP_Authenticated_Encryption_and_Decryption +function _M:update_aead_aad(aad) + if not self.initialized then + return nil, "cipher:update_aead_aad: cipher not initalized, call cipher:init first" + end + + if C.EVP_CipherUpdate(self.ctx, nil, out_length, aad, #aad) ~= 1 then + return false, format_error("cipher:update_aead_aad") + end + return true +end + +function _M:get_aead_tag(size) + if not self.initialized then + return nil, "cipher:get_aead_tag: cipher not initalized, call cipher:init first" + end + + size = size or self.key_size / 2 + if size > self.key_size then + return nil, string.format("tag size %d is too large", size) + end + if C.EVP_CIPHER_CTX_ctrl(self.ctx, evp_macro.EVP_CTRL_AEAD_GET_TAG, size, out_buffer) ~= 1 then + return nil, format_error("cipher:get_aead_tag") + end + + return ffi_str(out_buffer, size) +end + +function _M:set_aead_tag(tag) + if not self.initialized then + return nil, "cipher:set_aead_tag: cipher not initalized, call cipher:init first" + end + + if type(tag) ~= "string" then + return false, "cipher:set_aead_tag expect a string at #1" + end + local tag_void_ptr = ffi_cast(void_ptr, tag) + if C.EVP_CIPHER_CTX_ctrl(self.ctx, evp_macro.EVP_CTRL_AEAD_SET_TAG, #tag, tag_void_ptr) ~= 1 then + return false, format_error("cipher:set_aead_tag") + end + + return true +end + +function _M:update(...) + if not self.initialized then + return nil, "cipher:update: cipher not initalized, call cipher:init first" + end + + local ret = {} + for i, s in ipairs({...}) do + local inl = #s + if inl > 1024 then + s = ffi_cast(uchar_ptr, s) + for i=0, inl-1, 1024 do + local chunk_size = 1024 + if inl - i < 1024 then + chunk_size = inl - i + end + if C.EVP_CipherUpdate(self.ctx, out_buffer, out_length, s+i, chunk_size) ~= 1 then + return nil, format_error("cipher:update") + end + table.insert(ret, ffi_str(out_buffer, out_length[0])) + end + else + if C.EVP_CipherUpdate(self.ctx, out_buffer, out_length, s, inl) ~= 1 then + return nil, format_error("cipher:update") + end + table.insert(ret, ffi_str(out_buffer, out_length[0])) + end + end + return table.concat(ret, "") +end + +function _M:final(s) + local ret, err + if s then + ret, err = self:update(s) + if err then + return nil, err + end + end + if C.EVP_CipherFinal_ex(self.ctx, out_buffer, out_length) ~= 1 then + return nil, format_error("cipher:final: EVP_CipherFinal_ex") + end + local final_ret = ffi_str(out_buffer, out_length[0]) + return ret and (ret .. final_ret) or final_ret +end + + +function _M:derive(key, salt, count, md, md_properties) + if type(key) ~= "string" then + return nil, nil, "cipher:derive: expect a string at #1" + elseif salt and type(salt) ~= "string" then + return nil, nil, "cipher:derive: expect a string at #2" + elseif count then + count = tonumber(count) + if not count then + return nil, nil, "cipher:derive: expect a number at #3" + end + elseif md and type(md) ~= "string" then + return nil, nil, "cipher:derive: expect a string or nil at #4" + end + + if salt then + if #salt > 8 then + ngx.log(ngx.WARN, "cipher:derive: salt is too long, truncate salt to 8 bytes") + salt = salt:sub(0, 8) + elseif #salt < 8 then + ngx.log(ngx.WARN, "cipher:derive: salt is too short, padding with zero bytes to length") + salt = salt .. string.rep('\000', 8 - #salt) + end + end + + local mdt + if OPENSSL_3X then + mdt = C.EVP_MD_fetch(ctx_lib.get_libctx(), md or 'sha1', md_properties) + else + mdt = C.EVP_get_digestbyname(md or 'sha1') + end + if mdt == nil then + return nil, nil, string.format("cipher:derive: invalid digest type \"%s\"", md) + end + local cipt = C.EVP_CIPHER_CTX_cipher(self.ctx) + local keyb = uchar_array(self.key_size) + local ivb = uchar_array(self.iv_size) + + local size = C.EVP_BytesToKey(cipt, mdt, salt, + key, #key, count or 1, + keyb, ivb) + if size == 0 then + return nil, nil, format_error("cipher:derive: EVP_BytesToKey") + end + + return ffi_str(keyb, size), ffi_str(ivb, self.iv_size) +end + +return _M diff --git a/server/resty/openssl/ctx.lua b/server/resty/openssl/ctx.lua new file mode 100644 index 0000000..eaec396 --- /dev/null +++ b/server/resty/openssl/ctx.lua @@ -0,0 +1,78 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +require "resty.openssl.include.ossl_typ" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +ffi.cdef [[ + OSSL_LIB_CTX *OSSL_LIB_CTX_new(void); + int OSSL_LIB_CTX_load_config(OSSL_LIB_CTX *ctx, const char *config_file); + void OSSL_LIB_CTX_free(OSSL_LIB_CTX *ctx); +]] + +local ossl_lib_ctx + +local function new(request_context_only, conf_file) + if not OPENSSL_3X then + return false, "ctx is only supported from OpenSSL 3.0" + end + + local ctx = C.OSSL_LIB_CTX_new() + ffi_gc(ctx, C.OSSL_LIB_CTX_free) + + if conf_file and C.OSSL_LIB_CTX_load_config(ctx, conf_file) ~= 1 then + return false, format_error("ctx.new") + end + + if request_context_only then + ngx.ctx.ossl_lib_ctx = ctx + else + ossl_lib_ctx = ctx + end + + return true +end + +local function free(request_context_only) + if not OPENSSL_3X then + return false, "ctx is only supported from OpenSSL 3.0" + end + + if request_context_only then + ngx.ctx.ossl_lib_ctx = nil + else + ossl_lib_ctx = nil + end + + return true +end + +local test_request + +do + + local ok, exdata = pcall(require, "thread.exdata") + if ok and exdata then + test_request = function() + local r = exdata() + if r ~= nil then + return not not r + end + end + + else + local getfenv = getfenv + + function test_request() + return not not getfenv(0).__ngx_req + end + end +end + +return { + new = new, + free = free, + get_libctx = function() return test_request() and ngx.ctx.ossl_lib_ctx or ossl_lib_ctx end, +}
\ No newline at end of file diff --git a/server/resty/openssl/dh.lua b/server/resty/openssl/dh.lua new file mode 100644 index 0000000..93e4941 --- /dev/null +++ b/server/resty/openssl/dh.lua @@ -0,0 +1,142 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.dh" +local bn_lib = require "resty.openssl.bn" + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +_M.params = {"public", "private", "p", "q", "g"} + +local empty_table = {} +local bn_ptrptr_ct = ffi.typeof("const BIGNUM *[1]") +function _M.get_parameters(dh_st) + return setmetatable(empty_table, { + __index = function(_, k) + local ptr, ret + if OPENSSL_11_OR_LATER then + ptr = bn_ptrptr_ct() + end + + if OPENSSL_11_OR_LATER then + ptr = bn_ptrptr_ct() + end + + if k == 'p' then + if OPENSSL_11_OR_LATER then + C.DH_get0_pqg(dh_st, ptr, nil, nil) + end + elseif k == 'q' then + if OPENSSL_11_OR_LATER then + C.DH_get0_pqg(dh_st, nil, ptr, nil) + end + elseif k == 'g' then + if OPENSSL_11_OR_LATER then + C.DH_get0_pqg(dh_st, nil, nil, ptr) + end + elseif k == 'public' then + if OPENSSL_11_OR_LATER then + C.DH_get0_key(dh_st, ptr, nil) + end + k = "pub_key" + elseif k == 'private' then + if OPENSSL_11_OR_LATER then + C.DH_get0_key(dh_st, nil, ptr) + end + k = "priv_key" + else + return nil, "rsa.get_parameters: unknown parameter \"" .. k .. "\" for RSA key" + end + + if OPENSSL_11_OR_LATER then + ret = ptr[0] + elseif OPENSSL_10 then + ret = dh_st[k] + end + + if ret == nil then + return nil + end + return bn_lib.dup(ret) + end + }), nil +end + +local function dup_bn_value(v) + if not bn_lib.istype(v) then + return nil, "expect value to be a bn instance" + end + local bn = C.BN_dup(v.ctx) + if bn == nil then + return nil, "BN_dup() failed" + end + return bn +end + +function _M.set_parameters(dh_st, opts) + local err + local opts_bn = {} + -- remember which parts of BNs has been added to dh_st, they should be freed + -- by DH_free and we don't cleanup them on failure + local cleanup_from_idx = 1 + -- dup input + local do_set_key, do_set_pqg + for k, v in pairs(opts) do + opts_bn[k], err = dup_bn_value(v) + if err then + err = "dh.set_parameters: cannot process parameter \"" .. k .. "\":" .. err + goto cleanup_with_error + end + if k == "private" or k == "public" then + do_set_key = true + elseif k == "p" or k == "q" or k == "g" then + do_set_pqg = true + end + end + if OPENSSL_11_OR_LATER then + local code + if do_set_key then + code = C.DH_set0_key(dh_st, opts_bn["public"], opts_bn["private"]) + if code == 0 then + err = format_error("dh.set_parameters: DH_set0_key") + goto cleanup_with_error + end + end + cleanup_from_idx = cleanup_from_idx + 2 + if do_set_pqg then + code = C.DH_set0_pqg(dh_st, opts_bn["p"], opts_bn["q"], opts_bn["g"]) + if code == 0 then + err = format_error("dh.set_parameters: DH_set0_pqg") + goto cleanup_with_error + end + end + return true + elseif OPENSSL_10 then + for k, v in pairs(opts_bn) do + if k == "public" then + k = "pub_key" + elseif k == "private" then + k = "priv_key" + end + if dh_st[k] ~= nil then + C.BN_free(dh_st[k]) + end + dh_st[k]= v + end + return true + end + +::cleanup_with_error:: + for i, k in pairs(_M.params) do + if i >= cleanup_from_idx then + C.BN_free(opts_bn[k]) + end + end + return false, err +end + +return _M diff --git a/server/resty/openssl/digest.lua b/server/resty/openssl/digest.lua new file mode 100644 index 0000000..cfef9ae --- /dev/null +++ b/server/resty/openssl/digest.lua @@ -0,0 +1,116 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require "resty.openssl.include.evp.md" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local ctx_lib = require "resty.openssl.ctx" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local _M = {} +local mt = {__index = _M} + +local md_ctx_ptr_ct = ffi.typeof('EVP_MD_CTX*') + +function _M.new(typ, properties) + local ctx + if OPENSSL_11_OR_LATER then + ctx = C.EVP_MD_CTX_new() + ffi_gc(ctx, C.EVP_MD_CTX_free) + elseif OPENSSL_10 then + ctx = C.EVP_MD_CTX_create() + ffi_gc(ctx, C.EVP_MD_CTX_destroy) + end + if ctx == nil then + return nil, "digest.new: failed to create EVP_MD_CTX" + end + + local err_new = string.format("digest.new: invalid digest type \"%s\"", typ) + + local algo + if typ == "null" then + algo = C.EVP_md_null() + else + if OPENSSL_3X then + algo = C.EVP_MD_fetch(ctx_lib.get_libctx(), typ or 'sha1', properties) + else + algo = C.EVP_get_digestbyname(typ or 'sha1') + end + if algo == nil then + return nil, format_error(err_new) + end + end + + local code = C.EVP_DigestInit_ex(ctx, algo, nil) + if code ~= 1 then + return nil, format_error(err_new) + end + + return setmetatable({ + ctx = ctx, + algo = algo, + buf = ctypes.uchar_array(OPENSSL_3X and C.EVP_MD_get_size(algo) or C.EVP_MD_size(algo)), + }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(md_ctx_ptr_ct, l.ctx) +end + +function _M:get_provider_name() + if not OPENSSL_3X then + return false, "digest:get_provider_name is not supported" + end + local p = C.EVP_MD_get0_provider(self.algo) + if p == nil then + return nil + end + return ffi_str(C.OSSL_PROVIDER_get0_name(p)) +end + +if OPENSSL_3X then + local param_lib = require "resty.openssl.param" + _M.settable_params, _M.set_params, _M.gettable_params, _M.get_param = param_lib.get_params_func("EVP_MD_CTX") +end + +function _M:update(...) + for _, s in ipairs({...}) do + if C.EVP_DigestUpdate(self.ctx, s, #s) ~= 1 then + return false, format_error("digest:update") + end + end + return true, nil +end + +local result_length = ctypes.ptr_of_uint() + +function _M:final(s) + if s then + if C.EVP_DigestUpdate(self.ctx, s, #s) ~= 1 then + return false, format_error("digest:final") + end + end + + -- no return value of EVP_DigestFinal_ex + C.EVP_DigestFinal_ex(self.ctx, self.buf, result_length) + if result_length[0] == nil or result_length[0] <= 0 then + return nil, format_error("digest:final: EVP_DigestFinal_ex") + end + return ffi_str(self.buf, result_length[0]) +end + + +function _M:reset() + local code = C.EVP_DigestInit_ex(self.ctx, self.algo, nil) + if code ~= 1 then + return false, format_error("digest:reset") + end + + return true +end + +return _M diff --git a/server/resty/openssl/ec.lua b/server/resty/openssl/ec.lua new file mode 100644 index 0000000..2d0dd02 --- /dev/null +++ b/server/resty/openssl/ec.lua @@ -0,0 +1,186 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +require "resty.openssl.include.ec" +local bn_lib = require "resty.openssl.bn" +local objects_lib = require "resty.openssl.objects" +local ctypes = require "resty.openssl.auxiliary.ctypes" + +local version_num = require("resty.openssl.version").version_num +local format_error = require("resty.openssl.err").format_error +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +local _M = {} + +_M.params = {"group", "public", "private", "x", "y"} + +local empty_table = {} + +function _M.get_parameters(ec_key_st) + return setmetatable(empty_table, { + __index = function(_, k) + local group = C.EC_KEY_get0_group(ec_key_st) + local bn + + if k == 'group' then + local nid = C.EC_GROUP_get_curve_name(group) + if nid == 0 then + return nil, "ec.get_parameters: EC_GROUP_get_curve_name() failed" + end + return nid + elseif k == 'public' or k == "pub_key" then + local pub_point = C.EC_KEY_get0_public_key(ec_key_st) + if pub_point == nil then + return nil, format_error("ec.get_parameters: EC_KEY_get0_public_key") + end + local point_form = C.EC_KEY_get_conv_form(ec_key_st) + if point_form == nil then + return nil, format_error("ec.get_parameters: EC_KEY_get_conv_form") + end + if BORINGSSL then + local sz = tonumber(C.EC_POINT_point2oct(group, pub_point, point_form, nil, 0, bn_lib.bn_ctx_tmp)) + if sz <= 0 then + return nil, format_error("ec.get_parameters: EC_POINT_point2oct") + end + local buf = ctypes.uchar_array(sz) + C.EC_POINT_point2oct(group, pub_point, point_form, buf, sz, bn_lib.bn_ctx_tmp) + buf = ffi.string(buf, sz) + local err + bn, err = bn_lib.from_binary(buf) + if bn == nil then + return nil, "ec.get_parameters: bn_lib.from_binary: " .. err + end + return bn + else + bn = C.EC_POINT_point2bn(group, pub_point, point_form, nil, bn_lib.bn_ctx_tmp) + if bn == nil then + return nil, format_error("ec.get_parameters: EC_POINT_point2bn") + end + ffi_gc(bn, C.BN_free) + end + elseif k == 'private' or k == "priv_key" then + -- get0, don't GC + bn = C.EC_KEY_get0_private_key(ec_key_st) + elseif k == 'x' or k == 'y' then + local pub_point = C.EC_KEY_get0_public_key(ec_key_st) + if pub_point == nil then + return nil, format_error("ec.get_parameters: EC_KEY_get0_public_key") + end + bn = C.BN_new() + if bn == nil then + return nil, "ec.get_parameters: BN_new() failed" + end + ffi_gc(bn, C.BN_free) + local f + if version_num >= 0x10101000 then + f = C.EC_POINT_get_affine_coordinates + else + f = C.EC_POINT_get_affine_coordinates_GFp + end + local code + if k == 'x' then + code = f(group, pub_point, bn, nil, bn_lib.bn_ctx_tmp) + else + code = f(group, pub_point, nil, bn, bn_lib.bn_ctx_tmp) + end + if code ~= 1 then + return nil, format_error("ec.get_parameters: EC_POINT_get_affine_coordinates") + end + else + return nil, "ec.get_parameters: unknown parameter \"" .. k .. "\" for EC key" + end + + if bn == nil then + return nil + end + return bn_lib.dup(bn) + end + }), nil +end + +function _M.set_parameters(ec_key_st, opts) + for _, k in ipairs(_M.params) do + if k ~= "group" then + if opts[k] and not bn_lib.istype(opts[k]) then + return nil, "expect parameter \"" .. k .. "\" to be a bn instance" + end + end + end + + local group_nid = opts["group"] + local group + if group_nid then + local nid, err = objects_lib.txtnid2nid(group_nid) + if err then + return nil, "ec.set_parameters: cannot use parameter \"group\":" .. err + end + + group = C.EC_GROUP_new_by_curve_name(nid) + if group == nil then + return nil, "ec.set_parameters: EC_GROUP_new_by_curve_name() failed" + end + ffi_gc(group, C.EC_GROUP_free) + -- # define OPENSSL_EC_NAMED_CURVE 0x001 + C.EC_GROUP_set_asn1_flag(group, 1) + C.EC_GROUP_set_point_conversion_form(group, C.POINT_CONVERSION_UNCOMPRESSED) + + if C.EC_KEY_set_group(ec_key_st, group) ~= 1 then + return nil, format_error("ec.set_parameters: EC_KEY_set_group") + end + end + + local x = opts["x"] + local y = opts["y"] + local pub = opts["public"] + if (x and not y) or (y and not x) then + return nil, "ec.set_parameters: \"x\" and \"y\" parameter must be defined at same time or both undefined" + end + + if x and y then + if pub then + return nil, "ec.set_parameters: cannot set \"x\" and \"y\" with \"public\" at same time to set public key" + end + -- double check if we have set group already + if group == nil then + group = C.EC_KEY_get0_group(ec_key_st) + if group == nil then + return nil, "ec.set_parameters: cannot set public key without setting \"group\"" + end + end + + if C.EC_KEY_set_public_key_affine_coordinates(ec_key_st, x.ctx, y.ctx) ~= 1 then + return nil, format_error("ec.set_parameters: EC_KEY_set_public_key_affine_coordinates") + end + end + + if pub then + if group == nil then + group = C.EC_KEY_get0_group(ec_key_st) + if group == nil then + return nil, "ec.set_parameters: cannot set public key without setting \"group\"" + end + end + + local point = C.EC_POINT_bn2point(group, pub.ctx, nil, bn_lib.bn_ctx_tmp) + if point == nil then + return nil, format_error("ec.set_parameters: EC_POINT_bn2point") + end + ffi_gc(point, C.EC_POINT_free) + + if C.EC_KEY_set_public_key(ec_key_st, point) ~= 1 then + return nil, format_error("ec.set_parameters: EC_KEY_set_public_key") + end + end + + local priv = opts["private"] + if priv then + -- openssl duplicates it inside + if C.EC_KEY_set_private_key(ec_key_st, priv.ctx) ~= 1 then + return nil, format_error("ec.set_parameters: EC_KEY_set_private_key") + end + end + +end + +return _M diff --git a/server/resty/openssl/ecx.lua b/server/resty/openssl/ecx.lua new file mode 100644 index 0000000..5ec7162 --- /dev/null +++ b/server/resty/openssl/ecx.lua @@ -0,0 +1,67 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string + +require "resty.openssl.include.ec" +require "resty.openssl.include.evp" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +_M.params = {"public", "private"} + +local empty_table = {} + +local MAX_ECX_KEY_SIZE = 114 -- ed448 uses 114 bytes + +function _M.get_parameters(evp_pkey_st) + return setmetatable(empty_table, { + __index = function(_, k) + local buf = ctypes.uchar_array(MAX_ECX_KEY_SIZE) + local length = ctypes.ptr_of_size_t(MAX_ECX_KEY_SIZE) + + if k == 'public' or k == "pub_key" then + if C.EVP_PKEY_get_raw_public_key(evp_pkey_st, buf, length) ~= 1 then + error(format_error("ecx.get_parameters: EVP_PKEY_get_raw_private_key")) + end + elseif k == 'private' or k == "priv ~=_key" then + if C.EVP_PKEY_get_raw_private_key(evp_pkey_st, buf, length) ~= 1 then + return nil, format_error("ecx.get_parameters: EVP_PKEY_get_raw_private_key") + end + else + return nil, "ecx.get_parameters: unknown parameter \"" .. k .. "\" for EC key" + end + return ffi_str(buf, length[0]) + end + }), nil +end + +function _M.set_parameters(key_type, evp_pkey_st, opts) + -- for ecx keys we always create a new EVP_PKEY and release the old one + -- Note: we allow to pass a nil as evp_pkey_st to create a new EVP_PKEY + local key + if opts.private then + local priv = opts.private + key = C.EVP_PKEY_new_raw_private_key(key_type, nil, priv, #priv) + if key == nil then + return nil, format_error("ecx.set_parameters: EVP_PKEY_new_raw_private_key") + end + elseif opts.public then + local pub = opts.public + key = C.EVP_PKEY_new_raw_public_key(key_type, nil, pub, #pub) + if key == nil then + return nil, format_error("ecx.set_parameters: EVP_PKEY_new_raw_public_key") + end + else + return nil, "no parameter is specified" + end + + if evp_pkey_st ~= nil then + C.EVP_PKEY_free(evp_pkey_st) + end + return key + +end + +return _M diff --git a/server/resty/openssl/err.lua b/server/resty/openssl/err.lua new file mode 100644 index 0000000..a047a7c --- /dev/null +++ b/server/resty/openssl/err.lua @@ -0,0 +1,62 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string +local ffi_sizeof = ffi.sizeof + +local ctypes = require "resty.openssl.auxiliary.ctypes" +require "resty.openssl.include.err" + +local constchar_ptrptr = ffi.typeof("const char*[1]") + +local buf = ffi.new('char[256]') + +local function format_error(ctx, code, all_errors) + local errors = {} + if code then + table.insert(errors, string.format("code: %d", code or 0)) + end + -- get the OpenSSL errors + while C.ERR_peek_error() ~= 0 do + local line = ctypes.ptr_of_int() + local path = constchar_ptrptr() + local code + if all_errors then + code = C.ERR_get_error_line(path, line) + else + code = C.ERR_peek_last_error_line(path, line) + end + + local abs_path = ffi_str(path[0]) + -- ../crypto/asn1/a_d2i_fp.c => crypto/asn1/a_d2i_fp.c + local start = abs_path:find("/") + if start then + abs_path = abs_path:sub(start+1) + end + + C.ERR_error_string_n(code, buf, ffi_sizeof(buf)) + table.insert(errors, string.format("%s:%d:%s", + abs_path, line[0], ffi_str(buf)) + ) + + if not all_errors then + break + end + end + + C.ERR_clear_error() + + if #errors > 0 then + return string.format("%s%s%s", (ctx or ""), (ctx and ": " or ""), table.concat(errors, " ")) + else + return string.format("%s failed", ctx) + end +end + +local function format_all_error(ctx, code) + return format_error(ctx, code, true) +end + +return { + format_error = format_error, + format_all_error = format_all_error, +}
\ No newline at end of file diff --git a/server/resty/openssl/hmac.lua b/server/resty/openssl/hmac.lua new file mode 100644 index 0000000..fe18d2f --- /dev/null +++ b/server/resty/openssl/hmac.lua @@ -0,0 +1,90 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require "resty.openssl.include.hmac" +require "resty.openssl.include.evp.md" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local _M = {} +local mt = {__index = _M} + +local hmac_ctx_ptr_ct = ffi.typeof('HMAC_CTX*') + +-- Note: https://www.openssl.org/docs/manmaster/man3/HMAC_Init.html +-- Replace with EVP_MAC_* functions for OpenSSL 3.0 + +function _M.new(key, typ) + local ctx + if OPENSSL_11_OR_LATER then + ctx = C.HMAC_CTX_new() + ffi_gc(ctx, C.HMAC_CTX_free) + elseif OPENSSL_10 then + ctx = ffi.new('HMAC_CTX') + C.HMAC_CTX_init(ctx) + ffi_gc(ctx, C.HMAC_CTX_cleanup) + end + if ctx == nil then + return nil, "hmac.new: failed to create HMAC_CTX" + end + + local algo = C.EVP_get_digestbyname(typ or 'sha1') + if algo == nil then + return nil, string.format("hmac.new: invalid digest type \"%s\"", typ) + end + + local code = C.HMAC_Init_ex(ctx, key, #key, algo, nil) + if code ~= 1 then + return nil, format_error("hmac.new") + end + + return setmetatable({ + ctx = ctx, + algo = algo, + buf = ctypes.uchar_array(OPENSSL_3X and C.EVP_MD_get_size(algo) or C.EVP_MD_size(algo)), + }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(hmac_ctx_ptr_ct, l.ctx) +end + +function _M:update(...) + for _, s in ipairs({...}) do + if C.HMAC_Update(self.ctx, s, #s) ~= 1 then + return false, format_error("hmac:update") + end + end + return true, nil +end + +local result_length = ctypes.ptr_of_uint() + +function _M:final(s) + if s then + if C.HMAC_Update(self.ctx, s, #s) ~= 1 then + return false, format_error("hmac:final") + end + end + + if C.HMAC_Final(self.ctx, self.buf, result_length) ~= 1 then + return nil, format_error("hmac:final: HMAC_Final") + end + return ffi_str(self.buf, result_length[0]) +end + +function _M:reset() + local code = C.HMAC_Init_ex(self.ctx, nil, 0, nil, nil) + if code ~= 1 then + return false, format_error("hmac:reset") + end + + return true +end + +return _M diff --git a/server/resty/openssl/include/asn1.lua b/server/resty/openssl/include/asn1.lua new file mode 100644 index 0000000..ba59ebc --- /dev/null +++ b/server/resty/openssl/include/asn1.lua @@ -0,0 +1,94 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +ffi.cdef [[ + typedef struct ASN1_VALUE_st ASN1_VALUE; + + typedef struct asn1_type_st ASN1_TYPE; + + ASN1_IA5STRING *ASN1_IA5STRING_new(); + + int ASN1_STRING_type(const ASN1_STRING *x); + ASN1_STRING *ASN1_STRING_type_new(int type); + int ASN1_STRING_set(ASN1_STRING *str, const void *data, int len); + + ASN1_INTEGER *BN_to_ASN1_INTEGER(const BIGNUM *bn, ASN1_INTEGER *ai); + BIGNUM *ASN1_INTEGER_to_BN(const ASN1_INTEGER *ai, BIGNUM *bn); + + typedef int time_t; + ASN1_TIME *ASN1_TIME_set(ASN1_TIME *s, time_t t); + + int ASN1_INTEGER_set(ASN1_INTEGER *a, long v); + long ASN1_INTEGER_get(const ASN1_INTEGER *a); + int ASN1_ENUMERATED_set(ASN1_ENUMERATED *a, long v); + + int ASN1_STRING_print(BIO *bp, const ASN1_STRING *v); + + int ASN1_STRING_length(const ASN1_STRING *x); +]] + +local function declare_asn1_functions(typ, has_ex) + local t = {} + for i=1, 7 do + t[i] = typ + end + + ffi.cdef(string.format([[ + %s *%s_new(void); + void %s_free(%s *a); + %s *%s_dup(%s *a); + ]], unpack(t))) + + if OPENSSL_3X and has_ex then + ffi.cdef(string.format([[ + %s *%s_new_ex(OSSL_LIB_CTX *libctx, const char *propq); + ]], typ, typ)) + end +end + +declare_asn1_functions("ASN1_INTEGER") +declare_asn1_functions("ASN1_OBJECT") +declare_asn1_functions("ASN1_STRING") +declare_asn1_functions("ASN1_ENUMERATED") + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local BORINGSSL_110 = require("resty.openssl.version").BORINGSSL_110 + +local ASN1_STRING_get0_data +if OPENSSL_11_OR_LATER then + ffi.cdef[[ + const unsigned char *ASN1_STRING_get0_data(const ASN1_STRING *x); + ]] + ASN1_STRING_get0_data = C.ASN1_STRING_get0_data +elseif OPENSSL_10 then + ffi.cdef[[ + unsigned char *ASN1_STRING_data(ASN1_STRING *x); + typedef struct ASN1_ENCODING_st { + unsigned char *enc; /* DER encoding */ + long len; /* Length of encoding */ + int modified; /* set to 1 if 'enc' is invalid */ + } ASN1_ENCODING; + ]] + ASN1_STRING_get0_data = C.ASN1_STRING_data +end + +if BORINGSSL_110 then + ffi.cdef [[ + // required by resty/openssl/include/x509/crl.lua + typedef struct ASN1_ENCODING_st { + unsigned char *enc; /* DER encoding */ + long len; /* Length of encoding */ + int modified; /* set to 1 if 'enc' is invalid */ + } ASN1_ENCODING; + ]] +end + +return { + ASN1_STRING_get0_data = ASN1_STRING_get0_data, + declare_asn1_functions = declare_asn1_functions, + has_new_ex = true, +} diff --git a/server/resty/openssl/include/bio.lua b/server/resty/openssl/include/bio.lua new file mode 100644 index 0000000..45297fc --- /dev/null +++ b/server/resty/openssl/include/bio.lua @@ -0,0 +1,13 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + typedef struct bio_method_st BIO_METHOD; + long BIO_ctrl(BIO *bp, int cmd, long larg, void *parg); + BIO *BIO_new_mem_buf(const void *buf, int len); + BIO *BIO_new(const BIO_METHOD *type); + int BIO_free(BIO *a); + const BIO_METHOD *BIO_s_mem(void); + int BIO_read(BIO *b, void *data, int dlen); +]] diff --git a/server/resty/openssl/include/bn.lua b/server/resty/openssl/include/bn.lua new file mode 100644 index 0000000..93d2dda --- /dev/null +++ b/server/resty/openssl/include/bn.lua @@ -0,0 +1,77 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local BN_ULONG +if ffi.abi('64bit') then + BN_ULONG = 'unsigned long long' +else -- 32bit + BN_ULONG = 'unsigned int' +end + +ffi.cdef( +[[ + BIGNUM *BN_new(void); + void BN_free(BIGNUM *a); + + BN_CTX *BN_CTX_new(void); + void BN_CTX_init(BN_CTX *c); + void BN_CTX_free(BN_CTX *c); + + BIGNUM *BN_dup(const BIGNUM *a); + int BN_add_word(BIGNUM *a, ]] .. BN_ULONG ..[[ w); + int BN_set_word(BIGNUM *a, ]] .. BN_ULONG ..[[ w); + ]] .. BN_ULONG ..[[ BN_get_word(BIGNUM *a); + int BN_num_bits(const BIGNUM *a); + BIGNUM *BN_bin2bn(const unsigned char *s, int len, BIGNUM *ret); + int BN_hex2bn(BIGNUM **a, const char *str); + int BN_dec2bn(BIGNUM **a, const char *str); + int BN_bn2bin(const BIGNUM *a, unsigned char *to); + char *BN_bn2hex(const BIGNUM *a); + char *BN_bn2dec(const BIGNUM *a); + + void BN_set_negative(BIGNUM *a, int n); + int BN_is_negative(const BIGNUM *a); + + int BN_add(BIGNUM *r, const BIGNUM *a, const BIGNUM *b); + int BN_sub(BIGNUM *r, const BIGNUM *a, const BIGNUM *b); + int BN_mul(BIGNUM *r, BIGNUM *a, BIGNUM *b, BN_CTX *ctx); + int BN_sqr(BIGNUM *r, BIGNUM *a, BN_CTX *ctx); + int BN_div(BIGNUM *dv, BIGNUM *rem, const BIGNUM *a, const BIGNUM *d, + BN_CTX *ctx); + int BN_mod_add(BIGNUM *ret, BIGNUM *a, BIGNUM *b, const BIGNUM *m, + BN_CTX *ctx); + int BN_mod_sub(BIGNUM *ret, BIGNUM *a, BIGNUM *b, const BIGNUM *m, + BN_CTX *ctx); + int BN_mod_mul(BIGNUM *ret, BIGNUM *a, BIGNUM *b, const BIGNUM *m, + BN_CTX *ctx); + int BN_mod_sqr(BIGNUM *ret, BIGNUM *a, const BIGNUM *m, BN_CTX *ctx); + int BN_exp(BIGNUM *r, BIGNUM *a, BIGNUM *p, BN_CTX *ctx); + int BN_mod_exp(BIGNUM *r, BIGNUM *a, const BIGNUM *p, + const BIGNUM *m, BN_CTX *ctx); + int BN_gcd(BIGNUM *r, BIGNUM *a, BIGNUM *b, BN_CTX *ctx); + + int BN_lshift(BIGNUM *r, const BIGNUM *a, int n); + int BN_rshift(BIGNUM *r, BIGNUM *a, int n); + + int BN_cmp(BIGNUM *a, BIGNUM *b); + int BN_ucmp(BIGNUM *a, BIGNUM *b); + + // openssl >= 1.1 only + int BN_is_zero(BIGNUM *a); + int BN_is_one(BIGNUM *a); + int BN_is_word(BIGNUM *a, ]] .. BN_ULONG ..[[ w); + int BN_is_odd(BIGNUM *a); + + int BN_is_prime_ex(const BIGNUM *p,int nchecks, BN_CTX *ctx, BN_GENCB *cb); + int BN_generate_prime_ex(BIGNUM *ret,int bits,int safe, const BIGNUM *add, + const BIGNUM *rem, BN_GENCB *cb); +]] +) + +if OPENSSL_3X then + ffi.cdef [[ + int BN_check_prime(const BIGNUM *p, BN_CTX *ctx, BN_GENCB *cb); + ]] +end
\ No newline at end of file diff --git a/server/resty/openssl/include/conf.lua b/server/resty/openssl/include/conf.lua new file mode 100644 index 0000000..d655993 --- /dev/null +++ b/server/resty/openssl/include/conf.lua @@ -0,0 +1,9 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + CONF *NCONF_new(CONF_METHOD *meth); + void NCONF_free(CONF *conf); + int NCONF_load_bio(CONF *conf, BIO *bp, long *eline); +]]
\ No newline at end of file diff --git a/server/resty/openssl/include/crypto.lua b/server/resty/openssl/include/crypto.lua new file mode 100644 index 0000000..6ca1f08 --- /dev/null +++ b/server/resty/openssl/include/crypto.lua @@ -0,0 +1,31 @@ +local ffi = require "ffi" +local C = ffi.C + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER + +local OPENSSL_free +if OPENSSL_10 then + ffi.cdef [[ + void CRYPTO_free(void *ptr); + ]] + OPENSSL_free = C.CRYPTO_free +elseif OPENSSL_11_OR_LATER then + ffi.cdef [[ + void CRYPTO_free(void *ptr, const char *file, int line); + ]] + OPENSSL_free = function(ptr) + -- file and line is for debuggin only, since we can't know the c file info + -- the macro is expanded, just ignore this + C.CRYPTO_free(ptr, "", 0) + end +end + +ffi.cdef [[ + int FIPS_mode(void); + int FIPS_mode_set(int ONOFF); +]] + +return { + OPENSSL_free = OPENSSL_free, +} diff --git a/server/resty/openssl/include/dh.lua b/server/resty/openssl/include/dh.lua new file mode 100644 index 0000000..504879d --- /dev/null +++ b/server/resty/openssl/include/dh.lua @@ -0,0 +1,80 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.objects" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + void DH_get0_pqg(const DH *dh, + const BIGNUM **p, const BIGNUM **q, const BIGNUM **g); + int DH_set0_pqg(DH *dh, BIGNUM *p, BIGNUM *q, BIGNUM *g); + void DH_get0_key(const DH *dh, + const BIGNUM **pub_key, const BIGNUM **priv_key); + int DH_set0_key(DH *dh, BIGNUM *pub_key, BIGNUM *priv_key); + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + struct dh_st { + /* + * This first argument is used to pick up errors when a DH is passed + * instead of a EVP_PKEY + */ + int pad; + int version; + BIGNUM *p; + BIGNUM *g; + long length; /* optional */ + BIGNUM *pub_key; /* g^x */ + BIGNUM *priv_key; /* x */ + int flags; + /*BN_MONT_CTX*/ void *method_mont_p; + /* Place holders if we want to do X9.42 DH */ + BIGNUM *q; + BIGNUM *j; + unsigned char *seed; + int seedlen; + BIGNUM *counter; + int references; + /* trimmer */ + // CRYPTO_EX_DATA ex_data; + // const DH_METHOD *meth; + // ENGINE *engine; + }; + ]] +end + +ffi.cdef [[ + DH *DH_get_1024_160(void); + DH *DH_get_2048_224(void); + DH *DH_get_2048_256(void); + DH *DH_new_by_nid(int nid); +]]; + + +local dh_groups = { + -- per https://tools.ietf.org/html/rfc5114 + dh_1024_160 = function() return C.DH_get_1024_160() end, + dh_2048_224 = function() return C.DH_get_2048_224() end, + dh_2048_256 = function() return C.DH_get_2048_256() end, +} + +local groups = { + "ffdhe2048", "ffdhe3072", "ffdhe4096", "ffdhe6144", "ffdhe8192", + "modp_2048", "modp_3072", "modp_4096", "modp_6144", "modp_8192", + -- following cannot be used with FIPS provider + "modp_1536", -- and the RFC5114 ones +} + +for _, group in ipairs(groups) do + local nid = C.OBJ_sn2nid(group) + if nid ~= 0 then + dh_groups[group] = function() return C.DH_new_by_nid(nid) end + end +end + +return { + dh_groups = dh_groups, +} diff --git a/server/resty/openssl/include/ec.lua b/server/resty/openssl/include/ec.lua new file mode 100644 index 0000000..674ef42 --- /dev/null +++ b/server/resty/openssl/include/ec.lua @@ -0,0 +1,59 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + /** Enum for the point conversion form as defined in X9.62 (ECDSA) + * for the encoding of a elliptic curve point (x,y) */ + typedef enum { + /** the point is encoded as z||x, where the octet z specifies + * which solution of the quadratic equation y is */ + POINT_CONVERSION_COMPRESSED = 2, + /** the point is encoded as z||x||y, where z is the octet 0x04 */ + POINT_CONVERSION_UNCOMPRESSED = 4, + /** the point is encoded as z||x||y, where the octet z specifies + * which solution of the quadratic equation y is */ + POINT_CONVERSION_HYBRID = 6 + } point_conversion_form_t; + + EC_KEY *EC_KEY_new(void); + void EC_KEY_free(EC_KEY *key); + + EC_GROUP *EC_GROUP_new_by_curve_name(int nid); + void EC_GROUP_set_asn1_flag(EC_GROUP *group, int flag); + void EC_GROUP_set_point_conversion_form(EC_GROUP *group, + point_conversion_form_t form); + void EC_GROUP_set_curve_name(EC_GROUP *group, int nid); + int EC_GROUP_get_curve_name(const EC_GROUP *group); + + + void EC_GROUP_free(EC_GROUP *group); + + BIGNUM *EC_POINT_point2bn(const EC_GROUP *, const EC_POINT *, + point_conversion_form_t form, BIGNUM *, BN_CTX *); + // for BoringSSL + size_t EC_POINT_point2oct(const EC_GROUP *group, const EC_POINT *p, + point_conversion_form_t form, + unsigned char *buf, size_t len, BN_CTX *ctx); + // OpenSSL < 1.1.1 + int EC_POINT_get_affine_coordinates_GFp(const EC_GROUP *group, + const EC_POINT *p, + BIGNUM *x, BIGNUM *y, BN_CTX *ctx); + // OpenSSL >= 1.1.1 + int EC_POINT_get_affine_coordinates(const EC_GROUP *group, const EC_POINT *p, + BIGNUM *x, BIGNUM *y, BN_CTX *ctx); + EC_POINT *EC_POINT_bn2point(const EC_GROUP *group, const BIGNUM *bn, + EC_POINT *p, BN_CTX *ctx); + + point_conversion_form_t EC_KEY_get_conv_form(const EC_KEY *key); + + const BIGNUM *EC_KEY_get0_private_key(const EC_KEY *key); + int EC_KEY_set_private_key(EC_KEY *key, const BIGNUM *prv); + + const EC_POINT *EC_KEY_get0_public_key(const EC_KEY *key); + int EC_KEY_set_public_key(EC_KEY *key, const EC_POINT *pub); + int EC_KEY_set_public_key_affine_coordinates(EC_KEY *key, BIGNUM *x, BIGNUM *y); + + const EC_GROUP *EC_KEY_get0_group(const EC_KEY *key); + int EC_KEY_set_group(EC_KEY *key, const EC_GROUP *group); +]] diff --git a/server/resty/openssl/include/err.lua b/server/resty/openssl/include/err.lua new file mode 100644 index 0000000..142098c --- /dev/null +++ b/server/resty/openssl/include/err.lua @@ -0,0 +1,9 @@ +local ffi = require "ffi" + +ffi.cdef [[ + unsigned long ERR_peek_error(void); + unsigned long ERR_peek_last_error_line(const char **file, int *line); + unsigned long ERR_get_error_line(const char **file, int *line); + void ERR_clear_error(void); + void ERR_error_string_n(unsigned long e, char *buf, size_t len); +]] diff --git a/server/resty/openssl/include/evp.lua b/server/resty/openssl/include/evp.lua new file mode 100644 index 0000000..beeaf91 --- /dev/null +++ b/server/resty/openssl/include/evp.lua @@ -0,0 +1,109 @@ +local ffi = require "ffi" +local C = ffi.C +local bit = require("bit") + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.err" +require "resty.openssl.include.objects" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +if BORINGSSL then + ffi.cdef [[ + int PKCS5_PBKDF2_HMAC(const char *password, size_t password_len, + const uint8_t *salt, size_t salt_len, + unsigned iterations, const EVP_MD *digest, + size_t key_len, uint8_t *out_key); + int EVP_PBE_scrypt(const char *password, size_t password_len, + const uint8_t *salt, size_t salt_len, + uint64_t N, uint64_t r, uint64_t p, + size_t max_mem, uint8_t *out_key, + size_t key_len); + ]] +else + ffi.cdef [[ + /* KDF */ + int PKCS5_PBKDF2_HMAC(const char *pass, int passlen, + const unsigned char *salt, int saltlen, int iter, + const EVP_MD *digest, int keylen, unsigned char *out); + + int EVP_PBE_scrypt(const char *pass, size_t passlen, + const unsigned char *salt, size_t saltlen, + uint64_t N, uint64_t r, uint64_t p, uint64_t maxmem, + unsigned char *key, size_t keylen); + ]] +end + +if OPENSSL_3X then + require "resty.openssl.include.provider" + + ffi.cdef [[ + int EVP_set_default_properties(OSSL_LIB_CTX *libctx, const char *propq); + int EVP_default_properties_enable_fips(OSSL_LIB_CTX *libctx, int enable); + int EVP_default_properties_is_fips_enabled(OSSL_LIB_CTX *libctx); + + // const OSSL_PROVIDER *EVP_RAND_get0_provider(const EVP_RAND *rand); + // EVP_RAND *EVP_RAND_fetch(OSSL_LIB_CTX *libctx, const char *algorithm, + // const char *properties); + ]] +end + +local EVP_PKEY_ALG_CTRL = 0x1000 + +local _M = { + EVP_PKEY_RSA = C.OBJ_txt2nid("rsaEncryption"), + EVP_PKEY_DH = C.OBJ_txt2nid("dhKeyAgreement"), + EVP_PKEY_EC = C.OBJ_txt2nid("id-ecPublicKey"), + EVP_PKEY_X25519 = C.OBJ_txt2nid("X25519"), + EVP_PKEY_ED25519 = C.OBJ_txt2nid("ED25519"), + EVP_PKEY_X448 = C.OBJ_txt2nid("X448"), + EVP_PKEY_ED448 = C.OBJ_txt2nid("ED448"), + + EVP_PKEY_OP_PARAMGEN = bit.lshift(1, 1), + EVP_PKEY_OP_KEYGEN = bit.lshift(1, 2), + EVP_PKEY_OP_SIGN = bit.lshift(1, 3), + EVP_PKEY_OP_VERIFY = bit.lshift(1, 4), + EVP_PKEY_OP_DERIVE = OPENSSL_3X and bit.lshift(1, 12) or bit.lshift(1, 10), + + EVP_PKEY_ALG_CTRL = EVP_PKEY_ALG_CTRL, + + + EVP_PKEY_CTRL_DH_PARAMGEN_PRIME_LEN = EVP_PKEY_ALG_CTRL + 1, + EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID = EVP_PKEY_ALG_CTRL + 1, + EVP_PKEY_CTRL_EC_PARAM_ENC = EVP_PKEY_ALG_CTRL + 2, + EVP_PKEY_CTRL_RSA_KEYGEN_BITS = EVP_PKEY_ALG_CTRL + 3, + EVP_PKEY_CTRL_RSA_KEYGEN_PUBEXP = EVP_PKEY_ALG_CTRL + 4, + EVP_PKEY_CTRL_RSA_PADDING = EVP_PKEY_ALG_CTRL + 1, + EVP_PKEY_CTRL_RSA_PSS_SALTLEN = EVP_PKEY_ALG_CTRL + 2, + + EVP_CTRL_AEAD_SET_IVLEN = 0x9, + EVP_CTRL_AEAD_GET_TAG = 0x10, + EVP_CTRL_AEAD_SET_TAG = 0x11, + + EVP_PKEY_CTRL_TLS_MD = EVP_PKEY_ALG_CTRL, + EVP_PKEY_CTRL_TLS_SECRET = EVP_PKEY_ALG_CTRL + 1, + EVP_PKEY_CTRL_TLS_SEED = EVP_PKEY_ALG_CTRL + 2, + EVP_PKEY_CTRL_HKDF_MD = EVP_PKEY_ALG_CTRL + 3, + EVP_PKEY_CTRL_HKDF_SALT = EVP_PKEY_ALG_CTRL + 4, + EVP_PKEY_CTRL_HKDF_KEY = EVP_PKEY_ALG_CTRL + 5, + EVP_PKEY_CTRL_HKDF_INFO = EVP_PKEY_ALG_CTRL + 6, + EVP_PKEY_CTRL_HKDF_MODE = EVP_PKEY_ALG_CTRL + 7, + EVP_PKEY_CTRL_PASS = EVP_PKEY_ALG_CTRL + 8, + EVP_PKEY_CTRL_SCRYPT_SALT = EVP_PKEY_ALG_CTRL + 9, + EVP_PKEY_CTRL_SCRYPT_N = EVP_PKEY_ALG_CTRL + 10, + EVP_PKEY_CTRL_SCRYPT_R = EVP_PKEY_ALG_CTRL + 11, + EVP_PKEY_CTRL_SCRYPT_P = EVP_PKEY_ALG_CTRL + 12, + EVP_PKEY_CTRL_SCRYPT_MAXMEM_BYTES = EVP_PKEY_ALG_CTRL + 13, +} + +-- clean up error occurs during OBJ_txt2* +C.ERR_clear_error() + +_M.ecx_curves = { + Ed25519 = _M.EVP_PKEY_ED25519, + X25519 = _M.EVP_PKEY_X25519, + Ed448 = _M.EVP_PKEY_ED448, + X448 = _M.EVP_PKEY_X448, +} + +return _M diff --git a/server/resty/openssl/include/evp/cipher.lua b/server/resty/openssl/include/evp/cipher.lua new file mode 100644 index 0000000..c803766 --- /dev/null +++ b/server/resty/openssl/include/evp/cipher.lua @@ -0,0 +1,123 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +ffi.cdef [[ + // openssl < 3.0 + int EVP_CIPHER_CTX_block_size(const EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_key_length(const EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_iv_length(const EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_set_padding(EVP_CIPHER_CTX *c, int pad); + + const EVP_CIPHER *EVP_CIPHER_CTX_cipher(const EVP_CIPHER_CTX *ctx); + const EVP_CIPHER *EVP_get_cipherbyname(const char *name); + int EVP_CIPHER_CTX_ctrl(EVP_CIPHER_CTX *ctx, int type, int arg, void *ptr); + int EVP_EncryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, + int *outl, const unsigned char *in, int inl); + int EVP_DecryptUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, + int *outl, const unsigned char *in, int inl); + + + int EVP_CipherInit_ex(EVP_CIPHER_CTX *ctx, + const EVP_CIPHER *cipher, ENGINE *impl, + const unsigned char *key, + const unsigned char *iv, int enc); + int EVP_CipherUpdate(EVP_CIPHER_CTX *ctx, unsigned char *out, + int *outl, const unsigned char *in, int inl); + int EVP_CipherFinal_ex(EVP_CIPHER_CTX *ctx, unsigned char *outm, + int *outl); + + // list functions + typedef void* fake_openssl_cipher_list_fn(const EVP_CIPHER *ciph, const char *from, + const char *to, void *x); + //void EVP_CIPHER_do_all_sorted(fake_openssl_cipher_list_fn*, void *arg); + void EVP_CIPHER_do_all_sorted(void (*fn) + (const EVP_CIPHER *ciph, const char *from, + const char *to, void *x), void *arg); + int EVP_CIPHER_nid(const EVP_CIPHER *cipher); +]] + +if BORINGSSL then + ffi.cdef [[ + int EVP_BytesToKey(const EVP_CIPHER *type, const EVP_MD *md, + const uint8_t *salt, const uint8_t *data, + size_t data_len, unsigned count, uint8_t *key, + uint8_t *iv); + ]] +else + ffi.cdef [[ + int EVP_BytesToKey(const EVP_CIPHER *type, const EVP_MD *md, + const unsigned char *salt, + const unsigned char *data, int datal, int count, + unsigned char *key, unsigned char *iv); + ]] +end + +if OPENSSL_3X then + require "resty.openssl.include.provider" + + ffi.cdef [[ + int EVP_CIPHER_CTX_get_block_size(const EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_get_key_length(const EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_get_iv_length(const EVP_CIPHER_CTX *ctx); + + int EVP_CIPHER_get_nid(const EVP_CIPHER *cipher); + + const OSSL_PROVIDER *EVP_CIPHER_get0_provider(const EVP_CIPHER *cipher); + EVP_CIPHER *EVP_CIPHER_fetch(OSSL_LIB_CTX *ctx, const char *algorithm, + const char *properties); + + typedef void* fake_openssl_cipher_provided_list_fn(EVP_CIPHER *cipher, void *arg); + void EVP_CIPHER_do_all_provided(OSSL_LIB_CTX *libctx, + fake_openssl_cipher_provided_list_fn*, + void *arg); + int EVP_CIPHER_up_ref(EVP_CIPHER *cipher); + void EVP_CIPHER_free(EVP_CIPHER *cipher); + + const char *EVP_CIPHER_get0_name(const EVP_CIPHER *cipher); + + int EVP_CIPHER_CTX_set_params(EVP_CIPHER_CTX *ctx, const OSSL_PARAM params[]); + const OSSL_PARAM *EVP_CIPHER_CTX_settable_params(EVP_CIPHER_CTX *ctx); + int EVP_CIPHER_CTX_get_params(EVP_CIPHER_CTX *ctx, OSSL_PARAM params[]); + const OSSL_PARAM *EVP_CIPHER_CTX_gettable_params(EVP_CIPHER_CTX *ctx); + ]] +end + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + EVP_CIPHER_CTX *EVP_CIPHER_CTX_new(void); + int EVP_CIPHER_CTX_reset(EVP_CIPHER_CTX *c); + void EVP_CIPHER_CTX_free(EVP_CIPHER_CTX *c); + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + void EVP_CIPHER_CTX_init(EVP_CIPHER_CTX *a); + int EVP_CIPHER_CTX_cleanup(EVP_CIPHER_CTX *a); + + // # define EVP_MAX_IV_LENGTH 16 + // # define EVP_MAX_BLOCK_LENGTH 32 + + struct evp_cipher_ctx_st { + const EVP_CIPHER *cipher; + ENGINE *engine; /* functional reference if 'cipher' is + * ENGINE-provided */ + int encrypt; /* encrypt or decrypt */ + int buf_len; /* number we have left */ + unsigned char oiv[16]; /* original iv EVP_MAX_IV_LENGTH */ + unsigned char iv[16]; /* working iv EVP_MAX_IV_LENGTH */ + unsigned char buf[32]; /* saved partial block EVP_MAX_BLOCK_LENGTH */ + int num; /* used by cfb/ofb/ctr mode */ + void *app_data; /* application stuff */ + int key_len; /* May change for variable length cipher */ + unsigned long flags; /* Various flags */ + void *cipher_data; /* per EVP data */ + int final_used; + int block_mask; + unsigned char final[32]; /* possible final block EVP_MAX_BLOCK_LENGTH */ + } /* EVP_CIPHER_CTX */ ; + ]] +end
\ No newline at end of file diff --git a/server/resty/openssl/include/evp/kdf.lua b/server/resty/openssl/include/evp/kdf.lua new file mode 100644 index 0000000..1fd408f --- /dev/null +++ b/server/resty/openssl/include/evp/kdf.lua @@ -0,0 +1,148 @@ +local ffi = require "ffi" +local ffi_cast = ffi.cast +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.evp.md" +local evp = require("resty.openssl.include.evp") +local ctypes = require "resty.openssl.auxiliary.ctypes" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +local void_ptr = ctypes.void_ptr + +local _M = { + EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND = 0, + EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY = 1, + EVP_PKEY_HKDEF_MODE_EXPAND_ONLY = 2, +} + +if OPENSSL_3X then + require "resty.openssl.include.provider" + + ffi.cdef [[ + const OSSL_PROVIDER *EVP_KDF_get0_provider(const EVP_KDF *kdf); + + typedef void* fake_openssl_kdf_provided_list_fn(EVP_KDF *kdf, void *arg); + void EVP_KDF_do_all_provided(OSSL_LIB_CTX *libctx, + fake_openssl_kdf_provided_list_fn*, + void *arg); + int EVP_KDF_up_ref(EVP_KDF *kdf); + void EVP_KDF_free(EVP_KDF *kdf); + + const char *EVP_KDF_get0_name(const EVP_KDF *kdf); + + EVP_KDF *EVP_KDF_fetch(OSSL_LIB_CTX *libctx, const char *algorithm, + const char *properties); + EVP_KDF_CTX *EVP_KDF_CTX_new(const EVP_KDF *kdf); + void EVP_KDF_CTX_free(EVP_KDF_CTX *ctx); + void EVP_KDF_CTX_reset(EVP_KDF_CTX *ctx); + + size_t EVP_KDF_CTX_get_kdf_size(EVP_KDF_CTX *ctx); + int EVP_KDF_derive(EVP_KDF_CTX *ctx, unsigned char *key, size_t keylen, + const OSSL_PARAM params[]); + + int EVP_KDF_CTX_get_params(EVP_KDF_CTX *ctx, OSSL_PARAM params[]); + int EVP_KDF_CTX_set_params(EVP_KDF_CTX *ctx, const OSSL_PARAM params[]); + const OSSL_PARAM *EVP_KDF_CTX_gettable_params(const EVP_KDF_CTX *ctx); + const OSSL_PARAM *EVP_KDF_CTX_settable_params(const EVP_KDF_CTX *ctx); + ]] +end + +if OPENSSL_3X or BORINGSSL then + ffi.cdef [[ + int EVP_PKEY_CTX_set_tls1_prf_md(EVP_PKEY_CTX *ctx, const EVP_MD *md); + int EVP_PKEY_CTX_set1_tls1_prf_secret(EVP_PKEY_CTX *pctx, + const unsigned char *sec, int seclen); + int EVP_PKEY_CTX_add1_tls1_prf_seed(EVP_PKEY_CTX *pctx, + const unsigned char *seed, int seedlen); + + int EVP_PKEY_CTX_set_hkdf_md(EVP_PKEY_CTX *ctx, const EVP_MD *md); + int EVP_PKEY_CTX_set1_hkdf_salt(EVP_PKEY_CTX *ctx, + const unsigned char *salt, int saltlen); + int EVP_PKEY_CTX_set1_hkdf_key(EVP_PKEY_CTX *ctx, + const unsigned char *key, int keylen); + int EVP_PKEY_CTX_set_hkdf_mode(EVP_PKEY_CTX *ctx, int mode); + int EVP_PKEY_CTX_add1_hkdf_info(EVP_PKEY_CTX *ctx, + const unsigned char *info, int infolen); + ]] + + _M.EVP_PKEY_CTX_set_tls1_prf_md = function(pctx, md) + return C.EVP_PKEY_CTX_set_tls1_prf_md(pctx, md) + end + _M.EVP_PKEY_CTX_set1_tls1_prf_secret = function(pctx, sec) + return C.EVP_PKEY_CTX_set1_tls1_prf_secret(pctx, sec, #sec) + end + _M.EVP_PKEY_CTX_add1_tls1_prf_seed = function(pctx, seed) + return C.EVP_PKEY_CTX_add1_tls1_prf_seed(pctx, seed, #seed) + end + + _M.EVP_PKEY_CTX_set_hkdf_md = function(pctx, md) + return C.EVP_PKEY_CTX_set_hkdf_md(pctx, md) + end + _M.EVP_PKEY_CTX_set1_hkdf_salt = function(pctx, salt) + return C.EVP_PKEY_CTX_set1_hkdf_salt(pctx, salt, #salt) + end + _M.EVP_PKEY_CTX_set1_hkdf_key = function(pctx, key) + return C.EVP_PKEY_CTX_set1_hkdf_key(pctx, key, #key) + end + _M.EVP_PKEY_CTX_set_hkdf_mode = function(pctx, mode) + return C.EVP_PKEY_CTX_set_hkdf_mode(pctx, mode) + end + _M.EVP_PKEY_CTX_add1_hkdf_info = function(pctx, info) + return C.EVP_PKEY_CTX_add1_hkdf_info(pctx, info, #info) + end + +else + _M.EVP_PKEY_CTX_set_tls1_prf_md = function(pctx, md) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_TLS_MD, + 0, ffi_cast(void_ptr, md)) + end + _M.EVP_PKEY_CTX_set1_tls1_prf_secret = function(pctx, sec) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_TLS_SECRET, + #sec, ffi_cast(void_ptr, sec)) + end + _M.EVP_PKEY_CTX_add1_tls1_prf_seed = function(pctx, seed) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_TLS_SEED, + #seed, ffi_cast(void_ptr, seed)) + end + + _M.EVP_PKEY_CTX_set_hkdf_md = function(pctx, md) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_HKDF_MD, + 0, ffi_cast(void_ptr, md)) + end + _M.EVP_PKEY_CTX_set1_hkdf_salt = function(pctx, salt) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_HKDF_SALT, + #salt, ffi_cast(void_ptr, salt)) + end + _M.EVP_PKEY_CTX_set1_hkdf_key = function(pctx, key) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_HKDF_KEY, + #key, ffi_cast(void_ptr, key)) + end + _M.EVP_PKEY_CTX_set_hkdf_mode = function(pctx, mode) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_HKDF_MODE, + mode, nil) + end + _M.EVP_PKEY_CTX_add1_hkdf_info = function(pctx, info) + return C.EVP_PKEY_CTX_ctrl(pctx, -1, + evp.EVP_PKEY_OP_DERIVE, + evp.EVP_PKEY_CTRL_HKDF_INFO, + #info, ffi_cast(void_ptr, info)) + end +end + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/include/evp/mac.lua b/server/resty/openssl/include/evp/mac.lua new file mode 100644 index 0000000..a831076 --- /dev/null +++ b/server/resty/openssl/include/evp/mac.lua @@ -0,0 +1,38 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.provider" + +ffi.cdef [[ + typedef struct evp_mac_st EVP_MAC; + typedef struct evp_mac_ctx_st EVP_MAC_CTX; + + EVP_MAC_CTX *EVP_MAC_CTX_new(EVP_MAC *mac); + void EVP_MAC_CTX_free(EVP_MAC_CTX *ctx); + + const OSSL_PROVIDER *EVP_MAC_get0_provider(const EVP_MAC *mac); + EVP_MAC *EVP_MAC_fetch(OSSL_LIB_CTX *libctx, const char *algorithm, + const char *properties); + + int EVP_MAC_init(EVP_MAC_CTX *ctx, const unsigned char *key, size_t keylen, + const OSSL_PARAM params[]); + int EVP_MAC_update(EVP_MAC_CTX *ctx, const unsigned char *data, size_t datalen); + int EVP_MAC_final(EVP_MAC_CTX *ctx, + unsigned char *out, size_t *outl, size_t outsize); + + size_t EVP_MAC_CTX_get_mac_size(EVP_MAC_CTX *ctx); + + typedef void* fake_openssl_mac_provided_list_fn(EVP_MAC *mac, void *arg); + void EVP_MAC_do_all_provided(OSSL_LIB_CTX *libctx, + fake_openssl_mac_provided_list_fn*, + void *arg); + int EVP_MAC_up_ref(EVP_MAC *mac); + void EVP_MAC_free(EVP_MAC *mac); + + const char *EVP_MAC_get0_name(const EVP_MAC *mac); + + int EVP_MAC_CTX_set_params(EVP_MAC_CTX *ctx, const OSSL_PARAM params[]); + const OSSL_PARAM *EVP_MAC_CTX_settable_params(EVP_MAC_CTX *ctx); + int EVP_MAC_CTX_get_params(EVP_MAC_CTX *ctx, OSSL_PARAM params[]); + const OSSL_PARAM *EVP_MAC_CTX_gettable_params(EVP_MAC_CTX *ctx); +]]
\ No newline at end of file diff --git a/server/resty/openssl/include/evp/md.lua b/server/resty/openssl/include/evp/md.lua new file mode 100644 index 0000000..1794ce1 --- /dev/null +++ b/server/resty/openssl/include/evp/md.lua @@ -0,0 +1,86 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +ffi.cdef [[ + int EVP_DigestInit_ex(EVP_MD_CTX *ctx, const EVP_MD *type, + ENGINE *impl); + int EVP_DigestUpdate(EVP_MD_CTX *ctx, const void *d, + size_t cnt); + int EVP_DigestFinal_ex(EVP_MD_CTX *ctx, unsigned char *md, + unsigned int *s); + const EVP_MD *EVP_get_digestbyname(const char *name); + int EVP_DigestUpdate(EVP_MD_CTX *ctx, const void *d, + size_t cnt); + int EVP_DigestFinal_ex(EVP_MD_CTX *ctx, unsigned char *md, + unsigned int *s); + + const EVP_MD *EVP_md_null(void); + // openssl < 3.0 + int EVP_MD_size(const EVP_MD *md); + int EVP_MD_type(const EVP_MD *md); + + typedef void* fake_openssl_md_list_fn(const EVP_MD *ciph, const char *from, + const char *to, void *x); + void EVP_MD_do_all_sorted(fake_openssl_md_list_fn*, void *arg); + + const EVP_MD *EVP_get_digestbyname(const char *name); +]] + +if OPENSSL_3X then + require "resty.openssl.include.provider" + + ffi.cdef [[ + int EVP_MD_get_size(const EVP_MD *md); + int EVP_MD_get_type(const EVP_MD *md); + const OSSL_PROVIDER *EVP_MD_get0_provider(const EVP_MD *md); + + EVP_MD *EVP_MD_fetch(OSSL_LIB_CTX *ctx, const char *algorithm, + const char *properties); + + typedef void* fake_openssl_md_provided_list_fn(EVP_MD *md, void *arg); + void EVP_MD_do_all_provided(OSSL_LIB_CTX *libctx, + fake_openssl_md_provided_list_fn*, + void *arg); + int EVP_MD_up_ref(EVP_MD *md); + void EVP_MD_free(EVP_MD *md); + + const char *EVP_MD_get0_name(const EVP_MD *md); + + int EVP_MD_CTX_set_params(EVP_MD_CTX *ctx, const OSSL_PARAM params[]); + const OSSL_PARAM *EVP_MD_CTX_settable_params(EVP_MD_CTX *ctx); + int EVP_MD_CTX_get_params(EVP_MD_CTX *ctx, OSSL_PARAM params[]); + const OSSL_PARAM *EVP_MD_CTX_gettable_params(EVP_MD_CTX *ctx); + ]] +end + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + EVP_MD_CTX *EVP_MD_CTX_new(void); + void EVP_MD_CTX_free(EVP_MD_CTX *ctx); + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + EVP_MD_CTX *EVP_MD_CTX_create(void); + void EVP_MD_CTX_destroy(EVP_MD_CTX *ctx); + + // crypto/evp/evp.h + // only needed for openssl 1.0.x where initializer for HMAC_CTX is not avaiable + // HACK: renamed from env_md_ctx_st to evp_md_ctx_st to match typedef (lazily) + // it's an internal struct thus name is not exported so we will be fine + struct evp_md_ctx_st { + const EVP_MD *digest; + ENGINE *engine; /* functional reference if 'digest' is + * ENGINE-provided */ + unsigned long flags; + void *md_data; + /* Public key context for sign/verify */ + EVP_PKEY_CTX *pctx; + /* Update function: usually copied from EVP_MD */ + int (*update) (EVP_MD_CTX *ctx, const void *data, size_t count); + } /* EVP_MD_CTX */ ; + ]] +end
\ No newline at end of file diff --git a/server/resty/openssl/include/evp/pkey.lua b/server/resty/openssl/include/evp/pkey.lua new file mode 100644 index 0000000..ee1a213 --- /dev/null +++ b/server/resty/openssl/include/evp/pkey.lua @@ -0,0 +1,234 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.evp.md" +local evp = require("resty.openssl.include.evp") +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +ffi.cdef [[ + EVP_PKEY *EVP_PKEY_new(void); + void EVP_PKEY_free(EVP_PKEY *pkey); + + RSA *EVP_PKEY_get0_RSA(EVP_PKEY *pkey); + EC_KEY *EVP_PKEY_get0_EC_KEY(EVP_PKEY *pkey); + DH *EVP_PKEY_get0_DH(EVP_PKEY *pkey); + + int EVP_PKEY_assign(EVP_PKEY *pkey, int type, void *key); + // openssl < 3.0 + int EVP_PKEY_base_id(const EVP_PKEY *pkey); + int EVP_PKEY_size(const EVP_PKEY *pkey); + + EVP_PKEY_CTX *EVP_PKEY_CTX_new(EVP_PKEY *pkey, ENGINE *e); + EVP_PKEY_CTX *EVP_PKEY_CTX_new_id(int id, ENGINE *e); + void EVP_PKEY_CTX_free(EVP_PKEY_CTX *ctx); + int EVP_PKEY_CTX_ctrl(EVP_PKEY_CTX *ctx, int keytype, int optype, + int cmd, int p1, void *p2); + // TODO replace EVP_PKEY_CTX_ctrl with EVP_PKEY_CTX_ctrl_str to reduce + // some hardcoded macros + int EVP_PKEY_CTX_ctrl_str(EVP_PKEY_CTX *ctx, const char *type, + const char *value); + int EVP_PKEY_encrypt_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_encrypt(EVP_PKEY_CTX *ctx, + unsigned char *out, size_t *outlen, + const unsigned char *in, size_t inlen); + int EVP_PKEY_decrypt_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_decrypt(EVP_PKEY_CTX *ctx, + unsigned char *out, size_t *outlen, + const unsigned char *in, size_t inlen); + + int EVP_PKEY_sign_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_sign(EVP_PKEY_CTX *ctx, + unsigned char *sig, size_t *siglen, + const unsigned char *tbs, size_t tbslen); + int EVP_PKEY_verify_recover_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_verify_recover(EVP_PKEY_CTX *ctx, + unsigned char *rout, size_t *routlen, + const unsigned char *sig, size_t siglen); + + EVP_PKEY *EVP_PKEY_new_raw_private_key(int type, ENGINE *e, + const unsigned char *key, size_t keylen); + EVP_PKEY *EVP_PKEY_new_raw_public_key(int type, ENGINE *e, + const unsigned char *key, size_t keylen); + + int EVP_PKEY_get_raw_private_key(const EVP_PKEY *pkey, unsigned char *priv, + size_t *len); + int EVP_PKEY_get_raw_public_key(const EVP_PKEY *pkey, unsigned char *pub, + size_t *len); + + int EVP_SignFinal(EVP_MD_CTX *ctx, unsigned char *md, unsigned int *s, + EVP_PKEY *pkey); + int EVP_VerifyFinal(EVP_MD_CTX *ctx, const unsigned char *sigbuf, + unsigned int siglen, EVP_PKEY *pkey); + + int EVP_DigestSignInit(EVP_MD_CTX *ctx, EVP_PKEY_CTX **pctx, + const EVP_MD *type, ENGINE *e, EVP_PKEY *pkey); + int EVP_DigestSign(EVP_MD_CTX *ctx, unsigned char *sigret, + size_t *siglen, const unsigned char *tbs, + size_t tbslen); + int EVP_DigestVerifyInit(EVP_MD_CTX *ctx, EVP_PKEY_CTX **pctx, + const EVP_MD *type, ENGINE *e, EVP_PKEY *pkey); + int EVP_DigestVerify(EVP_MD_CTX *ctx, const unsigned char *sigret, + size_t siglen, const unsigned char *tbs, size_t tbslen); + + int EVP_PKEY_get_default_digest_nid(EVP_PKEY *pkey, int *pnid); + + int EVP_PKEY_derive_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_derive_set_peer(EVP_PKEY_CTX *ctx, EVP_PKEY *peer); + int EVP_PKEY_derive(EVP_PKEY_CTX *ctx, unsigned char *key, size_t *keylen); + + int EVP_PKEY_keygen_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_keygen(EVP_PKEY_CTX *ctx, EVP_PKEY **ppkey); + int EVP_PKEY_paramgen_init(EVP_PKEY_CTX *ctx); + int EVP_PKEY_paramgen(EVP_PKEY_CTX *ctx, EVP_PKEY **ppkey); +]] + +if OPENSSL_3X then + require "resty.openssl.include.provider" + + ffi.cdef [[ + int EVP_PKEY_CTX_set_rsa_padding(EVP_PKEY_CTX *ctx, int pad_mode); + + int EVP_PKEY_get_base_id(const EVP_PKEY *pkey); + int EVP_PKEY_get_size(const EVP_PKEY *pkey); + + const OSSL_PROVIDER *EVP_PKEY_get0_provider(const EVP_PKEY *key); + const OSSL_PROVIDER *EVP_PKEY_CTX_get0_provider(const EVP_PKEY_CTX *ctx); + + const OSSL_PARAM *EVP_PKEY_settable_params(const EVP_PKEY *pkey); + int EVP_PKEY_set_params(EVP_PKEY *pkey, OSSL_PARAM params[]); + int EVP_PKEY_get_params(EVP_PKEY *ctx, OSSL_PARAM params[]); + const OSSL_PARAM *EVP_PKEY_gettable_params(EVP_PKEY *ctx); + ]] +end + +if OPENSSL_10 then + ffi.cdef [[ + // crypto/evp/evp.h + // only needed for openssl 1.0.x where getters are not available + // needed to get key to extract parameters + // Note: this struct is trimmed + struct evp_pkey_st { + int type; + int save_type; + const EVP_PKEY_ASN1_METHOD *ameth; + ENGINE *engine; + ENGINE *pmeth_engine; + union { + void *ptr; + struct rsa_st *rsa; + struct dsa_st *dsa; + struct dh_st *dh; + struct ec_key_st *ec; + } pkey; + // trimmed + + // CRYPTO_REF_COUNT references; + // CRYPTO_RWLOCK *lock; + // STACK_OF(X509_ATTRIBUTE) *attributes; + // int save_parameters; + + // struct { + // EVP_KEYMGMT *keymgmt; + // void *provkey; + // } pkeys[10]; + // size_t dirty_cnt_copy; + }; + ]] +end + +local _M = {} + +if OPENSSL_3X or BORINGSSL then + ffi.cdef [[ + int EVP_PKEY_CTX_set_ec_paramgen_curve_nid(EVP_PKEY_CTX *ctx, int nid); + int EVP_PKEY_CTX_set_ec_param_enc(EVP_PKEY_CTX *ctx, int param_enc); + + int EVP_PKEY_CTX_set_rsa_keygen_bits(EVP_PKEY_CTX *ctx, int mbits); + int EVP_PKEY_CTX_set_rsa_keygen_pubexp(EVP_PKEY_CTX *ctx, BIGNUM *pubexp); + + int EVP_PKEY_CTX_set_rsa_padding(EVP_PKEY_CTX *ctx, int pad); + int EVP_PKEY_CTX_set_rsa_pss_saltlen(EVP_PKEY_CTX *ctx, int len); + + int EVP_PKEY_CTX_set_dh_paramgen_prime_len(EVP_PKEY_CTX *ctx, int pbits); + ]] + _M.EVP_PKEY_CTX_set_ec_paramgen_curve_nid = function(pctx, nid) + return C.EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pctx, nid) + end + _M.EVP_PKEY_CTX_set_ec_param_enc = function(pctx, param_enc) + return C.EVP_PKEY_CTX_set_ec_param_enc(pctx, param_enc) + end + + _M.EVP_PKEY_CTX_set_rsa_keygen_bits = function(pctx, mbits) + return C.EVP_PKEY_CTX_set_rsa_keygen_bits(pctx, mbits) + end + _M.EVP_PKEY_CTX_set_rsa_keygen_pubexp = function(pctx, pubexp) + return C.EVP_PKEY_CTX_set_rsa_keygen_pubexp(pctx, pubexp) + end + + _M.EVP_PKEY_CTX_set_rsa_padding = function(pctx, pad) + return C.EVP_PKEY_CTX_set_rsa_padding(pctx, pad) + end + _M.EVP_PKEY_CTX_set_rsa_pss_saltlen = function(pctx, len) + return C.EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, len) + end + _M.EVP_PKEY_CTX_set_dh_paramgen_prime_len = function(pctx, pbits) + return C.EVP_PKEY_CTX_set_dh_paramgen_prime_len(pctx, pbits) + end + +else + _M.EVP_PKEY_CTX_set_ec_paramgen_curve_nid = function(pctx, nid) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_EC, + evp.EVP_PKEY_OP_PARAMGEN + evp.EVP_PKEY_OP_KEYGEN, + evp.EVP_PKEY_CTRL_EC_PARAMGEN_CURVE_NID, + nid, nil) + end + _M.EVP_PKEY_CTX_set_ec_param_enc = function(pctx, param_enc) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_EC, + evp.EVP_PKEY_OP_PARAMGEN + evp.EVP_PKEY_OP_KEYGEN, + evp.EVP_PKEY_CTRL_EC_PARAM_ENC, + param_enc, nil) + end + + _M.EVP_PKEY_CTX_set_rsa_keygen_bits = function(pctx, mbits) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_RSA, + evp.EVP_PKEY_OP_KEYGEN, + evp.EVP_PKEY_CTRL_RSA_KEYGEN_BITS, + mbits, nil) + end + _M.EVP_PKEY_CTX_set_rsa_keygen_pubexp = function(pctx, pubexp) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_RSA, evp.EVP_PKEY_OP_KEYGEN, + evp.EVP_PKEY_CTRL_RSA_KEYGEN_PUBEXP, + 0, pubexp) + end + + _M.EVP_PKEY_CTX_set_rsa_padding = function(pctx, pad) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_RSA, + -1, + evp.EVP_PKEY_CTRL_RSA_PADDING, + pad, nil) + end + _M.EVP_PKEY_CTX_set_rsa_pss_saltlen = function(pctx, len) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_RSA, + evp.EVP_PKEY_OP_SIGN + evp.EVP_PKEY_OP_VERIFY, + evp.EVP_PKEY_CTRL_RSA_PSS_SALTLEN, + len, nil) + end + + _M.EVP_PKEY_CTX_set_dh_paramgen_prime_len = function(pctx, pbits) + return C.EVP_PKEY_CTX_ctrl(pctx, + evp.EVP_PKEY_DH, evp.EVP_PKEY_OP_PARAMGEN, + evp.EVP_PKEY_CTRL_DH_PARAMGEN_PRIME_LEN, + pbits, nil) + end +end + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/include/hmac.lua b/server/resty/openssl/include/hmac.lua new file mode 100644 index 0000000..e08f031 --- /dev/null +++ b/server/resty/openssl/include/hmac.lua @@ -0,0 +1,48 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.evp" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +if BORINGSSL then + ffi.cdef [[ + int HMAC_Init_ex(HMAC_CTX *ctx, const void *key, size_t key_len, + const EVP_MD *md, ENGINE *impl); + ]] +else + ffi.cdef [[ + int HMAC_Init_ex(HMAC_CTX *ctx, const void *key, int len, + const EVP_MD *md, ENGINE *impl); + ]] +end + +ffi.cdef [[ + int HMAC_Update(HMAC_CTX *ctx, const unsigned char *data, + size_t len); + int HMAC_Final(HMAC_CTX *ctx, unsigned char *md, + unsigned int *len); +]] + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + HMAC_CTX *HMAC_CTX_new(void); + void HMAC_CTX_free(HMAC_CTX *ctx); + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + // # define HMAC_MAX_MD_CBLOCK 128/* largest known is SHA512 */ + struct hmac_ctx_st { + const EVP_MD *md; + EVP_MD_CTX md_ctx; + EVP_MD_CTX i_ctx; + EVP_MD_CTX o_ctx; + unsigned int key_length; + unsigned char key[128]; + }; + + void HMAC_CTX_init(HMAC_CTX *ctx); + void HMAC_CTX_cleanup(HMAC_CTX *ctx); + ]] +end
\ No newline at end of file diff --git a/server/resty/openssl/include/objects.lua b/server/resty/openssl/include/objects.lua new file mode 100644 index 0000000..aecd324 --- /dev/null +++ b/server/resty/openssl/include/objects.lua @@ -0,0 +1,19 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + int OBJ_obj2txt(char *buf, int buf_len, const ASN1_OBJECT *a, int no_name); + ASN1_OBJECT *OBJ_txt2obj(const char *s, int no_name); + int OBJ_txt2nid(const char *s); + const char *OBJ_nid2sn(int n); + int OBJ_ln2nid(const char *s); + int OBJ_sn2nid(const char *s); + const char *OBJ_nid2ln(int n); + const char *OBJ_nid2sn(int n); + int OBJ_obj2nid(const ASN1_OBJECT *o); + const ASN1_OBJECT *OBJ_nid2obj(int n); + int OBJ_create(const char *oid, const char *sn, const char *ln); + + int OBJ_find_sigid_algs(int signid, int *pdig_nid, int *ppkey_nid); +]] diff --git a/server/resty/openssl/include/ossl_typ.lua b/server/resty/openssl/include/ossl_typ.lua new file mode 100644 index 0000000..198c889 --- /dev/null +++ b/server/resty/openssl/include/ossl_typ.lua @@ -0,0 +1,71 @@ +local ffi = require "ffi" + +ffi.cdef( +[[ + typedef struct rsa_st RSA; + typedef struct evp_pkey_st EVP_PKEY; + typedef struct bignum_st BIGNUM; + typedef struct bn_gencb_st BN_GENCB; + typedef struct bignum_ctx BN_CTX; + typedef struct bio_st BIO; + typedef struct evp_cipher_st EVP_CIPHER; + typedef struct evp_md_ctx_st EVP_MD_CTX; + typedef struct evp_pkey_ctx_st EVP_PKEY_CTX; + typedef struct evp_md_st EVP_MD; + typedef struct evp_pkey_asn1_method_st EVP_PKEY_ASN1_METHOD; + typedef struct evp_cipher_ctx_st EVP_CIPHER_CTX; + typedef struct engine_st ENGINE; + typedef struct x509_st X509; + typedef struct x509_attributes_st X509_ATTRIBUTE; + typedef struct X509_extension_st X509_EXTENSION; + typedef struct X509_name_st X509_NAME; + typedef struct X509_name_entry_st X509_NAME_ENTRY; + typedef struct X509_req_st X509_REQ; + typedef struct X509_crl_st X509_CRL; + typedef struct x509_store_st X509_STORE; + typedef struct x509_store_ctx_st X509_STORE_CTX; + typedef struct x509_purpose_st X509_PURPOSE; + typedef struct v3_ext_ctx X509V3_CTX; + typedef struct asn1_string_st ASN1_INTEGER; + typedef struct asn1_string_st ASN1_ENUMERATED; + typedef struct asn1_string_st ASN1_BIT_STRING; + typedef struct asn1_string_st ASN1_OCTET_STRING; + typedef struct asn1_string_st ASN1_PRINTABLESTRING; + typedef struct asn1_string_st ASN1_T61STRING; + typedef struct asn1_string_st ASN1_IA5STRING; + typedef struct asn1_string_st ASN1_GENERALSTRING; + typedef struct asn1_string_st ASN1_UNIVERSALSTRING; + typedef struct asn1_string_st ASN1_BMPSTRING; + typedef struct asn1_string_st ASN1_UTCTIME; + typedef struct asn1_string_st ASN1_TIME; + typedef struct asn1_string_st ASN1_GENERALIZEDTIME; + typedef struct asn1_string_st ASN1_VISIBLESTRING; + typedef struct asn1_string_st ASN1_UTF8STRING; + typedef struct asn1_string_st ASN1_STRING; + typedef struct asn1_object_st ASN1_OBJECT; + typedef struct conf_st CONF; + typedef struct conf_method_st CONF_METHOD; + typedef int ASN1_BOOLEAN; + typedef int ASN1_NULL; + typedef struct ec_key_st EC_KEY; + typedef struct ec_method_st EC_METHOD; + typedef struct ec_point_st EC_POINT; + typedef struct ec_group_st EC_GROUP; + typedef struct rsa_meth_st RSA_METHOD; + // typedef struct evp_keymgmt_st EVP_KEYMGMT; + // typedef struct crypto_ex_data_st CRYPTO_EX_DATA; + // typedef struct bn_mont_ctx_st BN_MONT_CTX; + // typedef struct bn_blinding_st BN_BLINDING; + // crypto.h + // typedef void CRYPTO_RWLOCK; + typedef struct hmac_ctx_st HMAC_CTX; + typedef struct x509_revoked_st X509_REVOKED; + typedef struct dh_st DH; + typedef struct PKCS12_st PKCS12; + typedef struct ssl_st SSL; + typedef struct ssl_ctx_st SSL_CTX; + typedef struct evp_kdf_st EVP_KDF; + typedef struct evp_kdf_ctx_st EVP_KDF_CTX; + typedef struct ossl_lib_ctx_st OSSL_LIB_CTX; +]]) + diff --git a/server/resty/openssl/include/param.lua b/server/resty/openssl/include/param.lua new file mode 100644 index 0000000..9c7a2e9 --- /dev/null +++ b/server/resty/openssl/include/param.lua @@ -0,0 +1,71 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + typedef struct ossl_param_st { + const char *key; /* the name of the parameter */ + unsigned int data_type; /* declare what kind of content is in buffer */ + void *data; /* value being passed in or out */ + size_t data_size; /* data size */ + size_t return_size; /* returned content size */ + } OSSL_PARAM; + + OSSL_PARAM OSSL_PARAM_construct_int(const char *key, int *buf); + OSSL_PARAM OSSL_PARAM_construct_uint(const char *key, unsigned int *buf); + OSSL_PARAM OSSL_PARAM_construct_BN(const char *key, unsigned char *buf, + size_t bsize); + OSSL_PARAM OSSL_PARAM_construct_double(const char *key, double *buf); + OSSL_PARAM OSSL_PARAM_construct_utf8_string(const char *key, char *buf, + size_t bsize); + OSSL_PARAM OSSL_PARAM_construct_octet_string(const char *key, void *buf, + size_t bsize); + OSSL_PARAM OSSL_PARAM_construct_utf8_ptr(const char *key, char **buf, + size_t bsize); + OSSL_PARAM OSSL_PARAM_construct_octet_ptr(const char *key, void **buf, + size_t bsize); + OSSL_PARAM OSSL_PARAM_construct_end(void); + + int OSSL_PARAM_get_int32(const OSSL_PARAM *p, int32_t *val); + int OSSL_PARAM_get_uint32(const OSSL_PARAM *p, uint32_t *val); + int OSSL_PARAM_get_int64(const OSSL_PARAM *p, int64_t *val); + int OSSL_PARAM_get_uint64(const OSSL_PARAM *p, uint64_t *val); + // int OSSL_PARAM_get_size_t(const OSSL_PARAM *p, size_t *val); + // int OSSL_PARAM_get_time_t(const OSSL_PARAM *p, time_t *val); + + int OSSL_PARAM_set_int(OSSL_PARAM *p, int val); + int OSSL_PARAM_set_uint(OSSL_PARAM *p, unsigned int val); + int OSSL_PARAM_set_long(OSSL_PARAM *p, long int val); + int OSSL_PARAM_set_ulong(OSSL_PARAM *p, unsigned long int val); + int OSSL_PARAM_set_int32(OSSL_PARAM *p, int32_t val); + int OSSL_PARAM_set_uint32(OSSL_PARAM *p, uint32_t val); + int OSSL_PARAM_set_int64(OSSL_PARAM *p, int64_t val); + int OSSL_PARAM_set_uint64(OSSL_PARAM *p, uint64_t val); + // int OSSL_PARAM_set_size_t(OSSL_PARAM *p, size_t val); + // int OSSL_PARAM_set_time_t(OSSL_PARAM *p, time_t val); + + int OSSL_PARAM_get_double(const OSSL_PARAM *p, double *val); + int OSSL_PARAM_set_double(OSSL_PARAM *p, double val); + + int OSSL_PARAM_get_BN(const OSSL_PARAM *p, BIGNUM **val); + int OSSL_PARAM_set_BN(OSSL_PARAM *p, const BIGNUM *val); + + int OSSL_PARAM_get_utf8_string(const OSSL_PARAM *p, char **val, size_t max_len); + int OSSL_PARAM_set_utf8_string(OSSL_PARAM *p, const char *val); + + int OSSL_PARAM_get_octet_string(const OSSL_PARAM *p, void **val, size_t max_len, + size_t *used_len); + int OSSL_PARAM_set_octet_string(OSSL_PARAM *p, const void *val, size_t len); + + int OSSL_PARAM_get_utf8_ptr(const OSSL_PARAM *p, const char **val); + int OSSL_PARAM_set_utf8_ptr(OSSL_PARAM *p, const char *val); + + int OSSL_PARAM_get_octet_ptr(const OSSL_PARAM *p, const void **val, + size_t *used_len); + int OSSL_PARAM_set_octet_ptr(OSSL_PARAM *p, const void *val, + size_t used_len); + + int OSSL_PARAM_get_utf8_string_ptr(const OSSL_PARAM *p, const char **val); + int OSSL_PARAM_get_octet_string_ptr(const OSSL_PARAM *p, const void **val, + size_t *used_len); +]] diff --git a/server/resty/openssl/include/pem.lua b/server/resty/openssl/include/pem.lua new file mode 100644 index 0000000..50185e5 --- /dev/null +++ b/server/resty/openssl/include/pem.lua @@ -0,0 +1,50 @@ + +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" + +ffi.cdef [[ + // all pem_password_cb* has been modified to pem_password_cb to avoid a table overflow issue + typedef int (*pem_password_cb)(char *buf, int size, int rwflag, void *userdata); + EVP_PKEY *PEM_read_bio_PrivateKey(BIO *bp, EVP_PKEY **x, + // the following signature has been modified to avoid ffi.cast + pem_password_cb cb, const char *u); + // pem_password_cb *cb, void *u); + EVP_PKEY *PEM_read_bio_PUBKEY(BIO *bp, EVP_PKEY **x, + // the following signature has been modified to avoid ffi.cast + pem_password_cb cb, const char *u); + // pem_password_cb *cb, void *u); + int PEM_write_bio_PrivateKey(BIO *bp, EVP_PKEY *x, const EVP_CIPHER *enc, + unsigned char *kstr, int klen, + pem_password_cb *cb, void *u); + int PEM_write_bio_PUBKEY(BIO *bp, EVP_PKEY *x); + + RSA *PEM_read_bio_RSAPrivateKey(BIO *bp, RSA **x, + // the following signature has been modified to avoid ffi.cast + pem_password_cb cb, const char *u); + // pem_password_cb *cb, void *u); + RSA *PEM_read_bio_RSAPublicKey(BIO *bp, RSA **x, + // the following signature has been modified to avoid ffi.cast + pem_password_cb cb, const char *u); + // pem_password_cb *cb, void *u); + int PEM_write_bio_RSAPrivateKey(BIO *bp, RSA *x, const EVP_CIPHER *enc, + unsigned char *kstr, int klen, + pem_password_cb *cb, void *u); + int PEM_write_bio_RSAPublicKey(BIO *bp, RSA *x); + + X509_REQ *PEM_read_bio_X509_REQ(BIO *bp, X509_REQ **x, pem_password_cb cb, void *u); + int PEM_write_bio_X509_REQ(BIO *bp, X509_REQ *x); + + X509_CRL *PEM_read_bio_X509_CRL(BIO *bp, X509_CRL **x, pem_password_cb cb, void *u); + int PEM_write_bio_X509_CRL(BIO *bp, X509_CRL *x); + + X509 *PEM_read_bio_X509(BIO *bp, X509 **x, pem_password_cb cb, void *u); + int PEM_write_bio_X509(BIO *bp, X509 *x); + + DH *PEM_read_bio_DHparams(BIO *bp, DH **x, pem_password_cb cb, void *u); + int PEM_write_bio_DHparams(BIO *bp, DH *x); + + EC_GROUP *PEM_read_bio_ECPKParameters(BIO *bp, EC_GROUP **x, pem_password_cb cb, void *u); + int PEM_write_bio_ECPKParameters(BIO *bp, const EC_GROUP *x); + +]] diff --git a/server/resty/openssl/include/pkcs12.lua b/server/resty/openssl/include/pkcs12.lua new file mode 100644 index 0000000..fb74025 --- /dev/null +++ b/server/resty/openssl/include/pkcs12.lua @@ -0,0 +1,31 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.stack" + +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +ffi.cdef [[ + // hack by changing char* to const char* here + PKCS12 *PKCS12_create(const char *pass, const char *name, EVP_PKEY *pkey, X509 *cert, + OPENSSL_STACK *ca, // STACK_OF(X509) + int nid_key, int nid_cert, int iter, int mac_iter, int keytype); + + int PKCS12_parse(PKCS12 *p12, const char *pass, EVP_PKEY **pkey, X509 **cert, + OPENSSL_STACK **ca); // STACK_OF(X509) **ca); + + void PKCS12_free(PKCS12 *p12); + int i2d_PKCS12_bio(BIO *bp, PKCS12 *a); + PKCS12 *d2i_PKCS12_bio(BIO *bp, PKCS12 **a); +]] + +if OPENSSL_3X then + ffi.cdef [[ + PKCS12 *PKCS12_create_ex(const char *pass, const char *name, EVP_PKEY *pkey, + X509 *cert, + OPENSSL_STACK *ca, // STACK_OF(X509) + int nid_key, int nid_cert, + int iter, int mac_iter, int keytype, + OSSL_LIB_CTX *ctx, const char *propq); + ]] +end diff --git a/server/resty/openssl/include/provider.lua b/server/resty/openssl/include/provider.lua new file mode 100644 index 0000000..a2bb472 --- /dev/null +++ b/server/resty/openssl/include/provider.lua @@ -0,0 +1,27 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.param" + +ffi.cdef [[ + typedef struct ossl_provider_st OSSL_PROVIDER; + typedef struct ossl_lib_ctx_st OSSL_LIB_CTX; + + void OSSL_PROVIDER_set_default_search_path(OSSL_LIB_CTX *libctx, + const char *path); + + + OSSL_PROVIDER *OSSL_PROVIDER_load(OSSL_LIB_CTX *libctx, const char *name); + OSSL_PROVIDER *OSSL_PROVIDER_try_load(OSSL_LIB_CTX *libctx, const char *name); + int OSSL_PROVIDER_unload(OSSL_PROVIDER *prov); + int OSSL_PROVIDER_available(OSSL_LIB_CTX *libctx, const char *name); + + const OSSL_PARAM *OSSL_PROVIDER_gettable_params(OSSL_PROVIDER *prov); + int OSSL_PROVIDER_get_params(OSSL_PROVIDER *prov, OSSL_PARAM params[]); + + // int OSSL_PROVIDER_add_builtin(OSSL_LIB_CTX *libctx, const char *name, + // ossl_provider_init_fn *init_fn); + + const char *OSSL_PROVIDER_get0_name(const OSSL_PROVIDER *prov); + int OSSL_PROVIDER_self_test(const OSSL_PROVIDER *prov); +]] diff --git a/server/resty/openssl/include/rand.lua b/server/resty/openssl/include/rand.lua new file mode 100644 index 0000000..90f44c1 --- /dev/null +++ b/server/resty/openssl/include/rand.lua @@ -0,0 +1,24 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +if BORINGSSL then + ffi.cdef [[ + int RAND_bytes(uint8_t *buf, size_t num); + int RAND_priv_bytes(uint8_t *buf, size_t num); + ]] +elseif OPENSSL_3X then + ffi.cdef [[ + int RAND_bytes_ex(OSSL_LIB_CTX *ctx, unsigned char *buf, size_t num, + unsigned int strength); + int RAND_priv_bytes_ex(OSSL_LIB_CTX *ctx, unsigned char *buf, size_t num, + unsigned int strength); + ]] +else + ffi.cdef [[ + int RAND_bytes(unsigned char *buf, int num); + int RAND_priv_bytes(unsigned char *buf, int num); + ]] +end diff --git a/server/resty/openssl/include/rsa.lua b/server/resty/openssl/include/rsa.lua new file mode 100644 index 0000000..d7de5f4 --- /dev/null +++ b/server/resty/openssl/include/rsa.lua @@ -0,0 +1,70 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER + +ffi.cdef [[ + RSA *RSA_new(void); + void RSA_free(RSA *r); +]] + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + void RSA_get0_key(const RSA *r, + const BIGNUM **n, const BIGNUM **e, const BIGNUM **d); + void RSA_get0_factors(const RSA *r, const BIGNUM **p, const BIGNUM **q); + void RSA_get0_crt_params(const RSA *r, + const BIGNUM **dmp1, const BIGNUM **dmq1, + const BIGNUM **iqmp); + + int RSA_set0_key(RSA *r, BIGNUM *n, BIGNUM *e, BIGNUM *d); + int RSA_set0_factors(RSA *r, BIGNUM *p, BIGNUM *q); + int RSA_set0_crt_params(RSA *r,BIGNUM *dmp1, BIGNUM *dmq1, BIGNUM *iqmp); + struct rsa_st; + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + // crypto/rsa/rsa_locl.h + // needed to extract parameters + // Note: this struct is trimmed + struct rsa_st { + int pad; + // the following has been changed in OpenSSL 1.1.x to int32_t + long version; + const RSA_METHOD *meth; + ENGINE *engine; + BIGNUM *n; + BIGNUM *e; + BIGNUM *d; + BIGNUM *p; + BIGNUM *q; + BIGNUM *dmp1; + BIGNUM *dmq1; + BIGNUM *iqmp; + // trimmed + + // CRYPTO_EX_DATA ex_data; + // int references; + // int flags; + // BN_MONT_CTX *_method_mod_n; + // BN_MONT_CTX *_method_mod_p; + // BN_MONT_CTX *_method_mod_q; + + // char *bignum_data; + // BN_BLINDING *blinding; + // BN_BLINDING *mt_blinding; + }; + ]] +end + +return { + paddings = { + RSA_PKCS1_PADDING = 1, + RSA_SSLV23_PADDING = 2, + RSA_NO_PADDING = 3, + RSA_PKCS1_OAEP_PADDING = 4, + RSA_X931_PADDING = 5, + RSA_PKCS1_PSS_PADDING = 6, + }, +} diff --git a/server/resty/openssl/include/ssl.lua b/server/resty/openssl/include/ssl.lua new file mode 100644 index 0000000..1219ac3 --- /dev/null +++ b/server/resty/openssl/include/ssl.lua @@ -0,0 +1,113 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.stack" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +ffi.cdef [[ + // SSL_METHOD + typedef struct ssl_method_st SSL_METHOD; + const SSL_METHOD *TLS_method(void); + const SSL_METHOD *TLS_server_method(void); + + // SSL_CIPHER + typedef struct ssl_cipher_st SSL_CIPHER; + const char *SSL_CIPHER_get_name(const SSL_CIPHER *cipher); + SSL_CIPHER *SSL_get_current_cipher(const SSL *ssl); + + SSL_CTX *SSL_CTX_new(const SSL_METHOD *meth); + void SSL_CTX_free(SSL_CTX *a); + + // SSL_SESSION + typedef struct ssl_session_st SSL_SESSION; + SSL_SESSION *SSL_get_session(const SSL *ssl); + long SSL_SESSION_set_timeout(SSL_SESSION *s, long t); + long SSL_SESSION_get_timeout(const SSL_SESSION *s); + + typedef int (*SSL_CTX_alpn_select_cb_func)(SSL *ssl, + const unsigned char **out, + unsigned char *outlen, + const unsigned char *in, + unsigned int inlen, + void *arg); + void SSL_CTX_set_alpn_select_cb(SSL_CTX *ctx, + SSL_CTX_alpn_select_cb_func cb, + void *arg); + + int SSL_select_next_proto(unsigned char **out, unsigned char *outlen, + const unsigned char *server, + unsigned int server_len, + const unsigned char *client, + unsigned int client_len); + + SSL *SSL_new(SSL_CTX *ctx); + void SSL_free(SSL *ssl); + + int SSL_set_cipher_list(SSL *ssl, const char *str); + int SSL_set_ciphersuites(SSL *s, const char *str); + + long SSL_set_options(SSL *ssl, long options); + long SSL_clear_options(SSL *ssl, long options); + long SSL_get_options(SSL *ssl); + + /*STACK_OF(SSL_CIPHER)*/ OPENSSL_STACK *SSL_get_ciphers(const SSL *ssl); + /*STACK_OF(SSL_CIPHER)*/ OPENSSL_STACK *SSL_CTX_get_ciphers(const SSL_CTX *ctx); + OPENSSL_STACK *SSL_get_peer_cert_chain(const SSL *ssl); + + typedef int (*verify_callback)(int preverify_ok, X509_STORE_CTX *x509_ctx); + void SSL_set_verify(SSL *s, int mode, + int (*verify_callback)(int, X509_STORE_CTX *)); + + int SSL_add_client_CA(SSL *ssl, X509 *cacert); + + long SSL_ctrl(SSL *ssl, int cmd, long larg, void *parg); +]] + +if OPENSSL_3X then + ffi.cdef [[ + X509 *SSL_get1_peer_certificate(const SSL *ssl); + ]] +else + ffi.cdef [[ + X509 *SSL_get_peer_certificate(const SSL *ssl); + ]] +end + +if BORINGSSL then + ffi.cdef [[ + int SSL_set_min_proto_version(SSL *ssl, int version); + int SSL_set_max_proto_version(SSL *ssl, int version); + ]] +end + +local SSL_CTRL_SET_MIN_PROTO_VERSION = 123 +local SSL_CTRL_SET_MAX_PROTO_VERSION = 124 + +local SSL_set_min_proto_version +if BORINGSSL then + SSL_set_min_proto_version = function(ctx, version) + return C.SSL_set_min_proto_version(ctx, version) + end +else + SSL_set_min_proto_version = function(ctx, version) + return C.SSL_ctrl(ctx, SSL_CTRL_SET_MIN_PROTO_VERSION, version, nil) + end +end + +local SSL_set_max_proto_version +if BORINGSSL then + SSL_set_max_proto_version = function(ctx, version) + return C.SSL_set_max_proto_version(ctx, version) + end +else + SSL_set_max_proto_version = function(ctx, version) + return C.SSL_ctrl(ctx, SSL_CTRL_SET_MAX_PROTO_VERSION, version, nil) + end +end + +return { + SSL_set_min_proto_version = SSL_set_min_proto_version, + SSL_set_max_proto_version = SSL_set_max_proto_version, +} diff --git a/server/resty/openssl/include/stack.lua b/server/resty/openssl/include/stack.lua new file mode 100644 index 0000000..5732608 --- /dev/null +++ b/server/resty/openssl/include/stack.lua @@ -0,0 +1,95 @@ +--[[ + The OpenSSL stack library. Note `safestack` is not usable here in ffi because + those symbols are eaten after preprocessing. + Instead, we should do a Lua land type checking by having a nested field indicating + which type of cdata its ctx holds. +]] + +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +local _M = {} + +ffi.cdef [[ + typedef char *OPENSSL_STRING; +]] + +if OPENSSL_11_OR_LATER and not BORINGSSL then + ffi.cdef [[ + typedef struct stack_st OPENSSL_STACK; + + OPENSSL_STACK *OPENSSL_sk_new_null(void); + int OPENSSL_sk_push(OPENSSL_STACK *st, const void *data); + void OPENSSL_sk_pop_free(OPENSSL_STACK *st, void (*func) (void *)); + int OPENSSL_sk_num(const OPENSSL_STACK *); + void *OPENSSL_sk_value(const OPENSSL_STACK *, int); + OPENSSL_STACK *OPENSSL_sk_dup(const OPENSSL_STACK *st); + void OPENSSL_sk_free(OPENSSL_STACK *); + void *OPENSSL_sk_delete(OPENSSL_STACK *st, int loc); + + typedef void (*OPENSSL_sk_freefunc)(void *); + typedef void *(*OPENSSL_sk_copyfunc)(const void *); + OPENSSL_STACK *OPENSSL_sk_deep_copy(const OPENSSL_STACK *, + OPENSSL_sk_copyfunc c, + OPENSSL_sk_freefunc f); + ]] + _M.OPENSSL_sk_pop_free = C.OPENSSL_sk_pop_free + + _M.OPENSSL_sk_new_null = C.OPENSSL_sk_new_null + _M.OPENSSL_sk_push = C.OPENSSL_sk_push + _M.OPENSSL_sk_pop_free = C.OPENSSL_sk_pop_free + _M.OPENSSL_sk_num = C.OPENSSL_sk_num + _M.OPENSSL_sk_value = C.OPENSSL_sk_value + _M.OPENSSL_sk_dup = C.OPENSSL_sk_dup + _M.OPENSSL_sk_delete = C.OPENSSL_sk_delete + _M.OPENSSL_sk_free = C.OPENSSL_sk_free + _M.OPENSSL_sk_deep_copy = C.OPENSSL_sk_deep_copy +elseif OPENSSL_10 or BORINGSSL then + ffi.cdef [[ + typedef struct stack_st _STACK; + // i made this up + typedef struct stack_st OPENSSL_STACK; + + _STACK *sk_new_null(void); + void sk_pop_free(_STACK *st, void (*func) (void *)); + _STACK *sk_dup(_STACK *st); + void sk_free(_STACK *st); + + _STACK *sk_deep_copy(_STACK *, void *(*)(void *), void (*)(void *)); + ]] + + if BORINGSSL then -- indices are using size_t instead of int + ffi.cdef [[ + size_t sk_push(_STACK *st, void *data); + size_t sk_num(const _STACK *); + void *sk_value(const _STACK *, size_t); + void *sk_delete(_STACK *st, size_t loc); + ]] + else -- normal OpenSSL 1.0 + ffi.cdef [[ + int sk_push(_STACK *st, void *data); + int sk_num(const _STACK *); + void *sk_value(const _STACK *, int); + void *sk_delete(_STACK *st, int loc); + ]] + end + + _M.OPENSSL_sk_pop_free = C.sk_pop_free + + _M.OPENSSL_sk_new_null = C.sk_new_null + _M.OPENSSL_sk_push = function(...) return tonumber(C.sk_push(...)) end + _M.OPENSSL_sk_pop_free = C.sk_pop_free + _M.OPENSSL_sk_num = function(...) return tonumber(C.sk_num(...)) end + _M.OPENSSL_sk_value = C.sk_value + _M.OPENSSL_sk_delete = C.sk_delete + _M.OPENSSL_sk_dup = C.sk_dup + _M.OPENSSL_sk_free = C.sk_free + _M.OPENSSL_sk_deep_copy = C.sk_deep_copy +end + +return _M diff --git a/server/resty/openssl/include/x509/altname.lua b/server/resty/openssl/include/x509/altname.lua new file mode 100644 index 0000000..ce1db67 --- /dev/null +++ b/server/resty/openssl/include/x509/altname.lua @@ -0,0 +1,49 @@ +local GEN_OTHERNAME = 0 +local GEN_EMAIL = 1 +local GEN_DNS = 2 +local GEN_X400 = 3 +local GEN_DIRNAME = 4 +local GEN_EDIPARTY = 5 +local GEN_URI = 6 +local GEN_IPADD = 7 +local GEN_RID = 8 + +local default_types = { + OtherName = GEN_OTHERNAME, -- otherName + RFC822Name = GEN_EMAIL, -- email + RFC822 = GEN_EMAIL, + Email = GEN_EMAIL, + DNSName = GEN_DNS, -- dns + DNS = GEN_DNS, + X400 = GEN_X400, -- x400 + DirName = GEN_DIRNAME, -- dirName + EdiParty = GEN_EDIPARTY, -- EdiParty + UniformResourceIdentifier = GEN_URI, -- uri + URI = GEN_URI, + IPAddress = GEN_IPADD, -- ipaddr + IP = GEN_IPADD, + RID = GEN_RID, -- rid +} + +local literals = { + [GEN_OTHERNAME] = "OtherName", + [GEN_EMAIL] = "email", + [GEN_DNS] = "DNS", + [GEN_X400] = "X400", + [GEN_DIRNAME] = "DirName", + [GEN_EDIPARTY] = "EdiParty", + [GEN_URI] = "URI", + [GEN_IPADD] = "IP", + [GEN_RID] = "RID", +} + +local types = {} +for t, gid in pairs(default_types) do + types[t:lower()] = gid + types[t] = gid +end + +return { + types = types, + literals = literals, +}
\ No newline at end of file diff --git a/server/resty/openssl/include/x509/crl.lua b/server/resty/openssl/include/x509/crl.lua new file mode 100644 index 0000000..7870cd3 --- /dev/null +++ b/server/resty/openssl/include/x509/crl.lua @@ -0,0 +1,86 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.evp" +require "resty.openssl.include.objects" +require "resty.openssl.include.x509" +require "resty.openssl.include.stack" + +local asn1_macro = require "resty.openssl.include.asn1" + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local BORINGSSL_110 = require("resty.openssl.version").BORINGSSL_110 + +asn1_macro.declare_asn1_functions("X509_CRL", asn1_macro.has_new_ex) + +ffi.cdef [[ + X509_NAME *X509_CRL_get_issuer(const X509_CRL *crl); + int X509_CRL_set_issuer_name(X509_CRL *x, X509_NAME *name); + int X509_CRL_set_version(X509_CRL *x, long version); + + int X509_CRL_add_ext(X509_CRL *x, X509_EXTENSION *ex, int loc); + X509_EXTENSION *X509_CRL_get_ext(const X509_CRL *x, int loc); + int X509_CRL_get_ext_by_NID(const X509_CRL *x, int nid, int lastpos); + void *X509_CRL_get_ext_d2i(const X509_CRL *x, int nid, int *crit, int *idx); + + int X509_CRL_sign(X509_CRL *x, EVP_PKEY *pkey, const EVP_MD *md); + int X509_CRL_verify(X509_CRL *a, EVP_PKEY *r); + + int i2d_X509_CRL_bio(BIO *bp, X509_CRL *crl); + X509_CRL *d2i_X509_CRL_bio(BIO *bp, X509_CRL **crl); + int X509_CRL_add0_revoked(X509_CRL *crl, X509_REVOKED *rev); + + int X509_CRL_print(BIO *bio, X509_CRL *crl); + + int X509_CRL_get0_by_serial(X509_CRL *crl, + X509_REVOKED **ret, ASN1_INTEGER *serial); + int X509_CRL_get0_by_cert(X509_CRL *crl, X509_REVOKED **ret, X509 *x); + + //STACK_OF(X509_REVOKED) + OPENSSL_STACK *X509_CRL_get_REVOKED(X509_CRL *crl); + + int X509_CRL_get0_by_serial(X509_CRL *crl, + X509_REVOKED **ret, ASN1_INTEGER *serial); +]] + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + int X509_CRL_set1_lastUpdate(X509_CRL *x, const ASN1_TIME *tm); + int X509_CRL_set1_nextUpdate(X509_CRL *x, const ASN1_TIME *tm); + /*const*/ ASN1_TIME *X509_CRL_get0_lastUpdate(const X509_CRL *crl); + /*const*/ ASN1_TIME *X509_CRL_get0_nextUpdate(const X509_CRL *crl); + long X509_CRL_get_version(const X509_CRL *crl); + + X509_EXTENSION *X509_CRL_delete_ext(X509_CRL *x, int loc); + + int X509_CRL_get_signature_nid(const X509_CRL *crl); + ]] +end +if OPENSSL_10 or BORINGSSL_110 then + -- in openssl 1.0.x some getters are direct accessor to struct members (defiend by macros) + ffi.cdef [[ + typedef struct X509_crl_info_st { + ASN1_INTEGER *version; + X509_ALGOR *sig_alg; + X509_NAME *issuer; + ASN1_TIME *lastUpdate; + ASN1_TIME *nextUpdate; + // STACK_OF(X509_REVOKED) + OPENSSL_STACK *revoked; + // STACK_OF(X509_EXTENSION) + OPENSSL_STACK /* [0] */ *extensions; + ASN1_ENCODING enc; + } X509_CRL_INFO; + + // Note: this struct is trimmed + struct X509_crl_st { + /* actual signature */ + X509_CRL_INFO *crl; + // trimmed + } /* X509_CRL */ ; + + int X509_CRL_set_lastUpdate(X509_CRL *x, const ASN1_TIME *tm); + int X509_CRL_set_nextUpdate(X509_CRL *x, const ASN1_TIME *tm); + ]] +end diff --git a/server/resty/openssl/include/x509/csr.lua b/server/resty/openssl/include/x509/csr.lua new file mode 100644 index 0000000..44c4801 --- /dev/null +++ b/server/resty/openssl/include/x509/csr.lua @@ -0,0 +1,88 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.evp" +require "resty.openssl.include.objects" +require "resty.openssl.include.x509" +require "resty.openssl.include.stack" + +local asn1_macro = require "resty.openssl.include.asn1" + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL_110 = require("resty.openssl.version").BORINGSSL_110 + +asn1_macro.declare_asn1_functions("X509_REQ", asn1_macro.has_new_ex) + +ffi.cdef [[ + int X509_REQ_set_subject_name(X509_REQ *req, X509_NAME *name); + + EVP_PKEY *X509_REQ_get_pubkey(X509_REQ *req); + int X509_REQ_set_pubkey(X509_REQ *x, EVP_PKEY *pkey); + + int X509_REQ_set_version(X509_REQ *x, long version); + + int X509_REQ_get_attr_count(const X509_REQ *req); + + int X509_CRL_add_ext(X509_CRL *x, X509_EXTENSION *ex, int loc); + X509_EXTENSION *X509_CRL_get_ext(const X509_CRL *x, int loc); + int X509_CRL_get_ext_by_NID(const X509_CRL *x, int nid, int lastpos); + + int i2d_re_X509_REQ_tbs(X509_REQ *req, unsigned char **pp); + void X509_ATTRIBUTE_free(X509_ATTRIBUTE *a); + int X509_REQ_get_attr_by_NID(const X509_REQ *req, int nid, int lastpos); + X509_ATTRIBUTE *X509_REQ_delete_attr(X509_REQ *req, int loc); + + int *X509_REQ_get_extension_nids(void); + + int X509_REQ_sign(X509_REQ *x, EVP_PKEY *pkey, const EVP_MD *md); + int X509_REQ_verify(X509_REQ *a, EVP_PKEY *r); + + int i2d_X509_REQ_bio(BIO *bp, X509_REQ *req); + X509_REQ *d2i_X509_REQ_bio(BIO *bp, X509_REQ **req); + + // STACK_OF(X509_EXTENSION) + OPENSSL_STACK *X509_REQ_get_extensions(X509_REQ *req); + // STACK_OF(X509_EXTENSION) + int X509_REQ_add_extensions(X509_REQ *req, OPENSSL_STACK *exts); + + int X509_REQ_check_private_key(X509_REQ *x, EVP_PKEY *k); +]] + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + X509_NAME *X509_REQ_get_subject_name(const X509_REQ *req); + long X509_REQ_get_version(const X509_REQ *req); + + int X509_REQ_get_signature_nid(const X509_REQ *crl); + ]] +end +if OPENSSL_10 or BORINGSSL_110 then + ffi.cdef [[ + typedef struct X509_req_info_st { + ASN1_ENCODING enc; + ASN1_INTEGER *version; + X509_NAME *subject; + /*X509_PUBKEY*/ void *pubkey; + /* d=2 hl=2 l= 0 cons: cont: 00 */ + /*STACK_OF(X509_ATTRIBUTE)*/ OPENSSL_STACK *attributes; /* [ 0 ] */ + } X509_REQ_INFO; + + // Note: this struct is trimmed + typedef struct X509_req_st { + X509_REQ_INFO *req_info; + X509_ALGOR *sig_alg; + // trimmed + //ASN1_BIT_STRING *signature; + //int references; + } X509_REQ; + ]] +end + +if OPENSSL_3X then + ffi.cdef [[ + int X509_REQ_verify_ex(X509_REQ *a, EVP_PKEY *pkey, OSSL_LIB_CTX *libctx, + const char *propq); + ]] +end diff --git a/server/resty/openssl/include/x509/extension.lua b/server/resty/openssl/include/x509/extension.lua new file mode 100644 index 0000000..14b231e --- /dev/null +++ b/server/resty/openssl/include/x509/extension.lua @@ -0,0 +1,44 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.x509v3" +require "resty.openssl.include.x509" +local asn1_macro = require "resty.openssl.include.asn1" +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +asn1_macro.declare_asn1_functions("X509_EXTENSION") + +if OPENSSL_3X then + ffi.cdef [[ + struct v3_ext_ctx { + int flags; + X509 *issuer_cert; + X509 *subject_cert; + X509_REQ *subject_req; + X509_CRL *crl; + /*X509V3_CONF_METHOD*/ void *db_meth; + void *db; + EVP_PKEY *issuer_pkey; + }; + + int X509V3_set_issuer_pkey(X509V3_CTX *ctx, EVP_PKEY *pkey); + ]] + +else + ffi.cdef [[ + struct v3_ext_ctx { + int flags; + X509 *issuer_cert; + X509 *subject_cert; + X509_REQ *subject_req; + X509_CRL *crl; + /*X509V3_CONF_METHOD*/ void *db_meth; + void *db; + }; + ]] +end + +ffi.cdef [[ + int X509_EXTENSION_set_data(X509_EXTENSION *ex, ASN1_OCTET_STRING *data); + int X509_EXTENSION_set_object(X509_EXTENSION *ex, const ASN1_OBJECT *obj); +]]
\ No newline at end of file diff --git a/server/resty/openssl/include/x509/init.lua b/server/resty/openssl/include/x509/init.lua new file mode 100644 index 0000000..ec104ef --- /dev/null +++ b/server/resty/openssl/include/x509/init.lua @@ -0,0 +1,138 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.bio" +require "resty.openssl.include.pem" +require "resty.openssl.include.stack" +local asn1_macro = require "resty.openssl.include.asn1" + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local BORINGSSL_110 = require("resty.openssl.version").BORINGSSL_110 + +asn1_macro.declare_asn1_functions("X509", asn1_macro.has_new_ex) + +ffi.cdef [[ + int i2d_X509_bio(BIO *bp, X509 *x509); + X509 *d2i_X509_bio(BIO *bp, X509 **x509); + + // STACK_OF(X509) + OPENSSL_STACK *X509_chain_up_ref(OPENSSL_STACK *chain); + + int X509_sign(X509 *x, EVP_PKEY *pkey, const EVP_MD *md); + int X509_verify(X509 *a, EVP_PKEY *r); + + ASN1_TIME *X509_gmtime_adj(ASN1_TIME *s, long adj); + + int X509_add_ext(X509 *x, X509_EXTENSION *ex, int loc); + X509_EXTENSION *X509_get_ext(const X509 *x, int loc); + int X509_get_ext_by_NID(const X509 *x, int nid, int lastpos); + void *X509_get_ext_d2i(const X509 *x, int nid, int *crit, int *idx); + + int X509_EXTENSION_set_critical(X509_EXTENSION *ex, int crit); + int X509_EXTENSION_get_critical(const X509_EXTENSION *ex); + ASN1_OBJECT *X509_EXTENSION_get_object(X509_EXTENSION *ex); + ASN1_OCTET_STRING *X509_EXTENSION_get_data(X509_EXTENSION *ne); + X509_EXTENSION *X509V3_EXT_i2d(int ext_nid, int crit, void *ext_struc); + X509_EXTENSION *X509_EXTENSION_create_by_NID(X509_EXTENSION **ex, + int nid, int crit, + ASN1_OCTET_STRING *data); + + // needed by pkey + EVP_PKEY *d2i_PrivateKey_bio(BIO *bp, EVP_PKEY **a); + EVP_PKEY *d2i_PUBKEY_bio(BIO *bp, EVP_PKEY **a); + int i2d_PrivateKey_bio(BIO *bp, EVP_PKEY *pkey); + int i2d_PUBKEY_bio(BIO *bp, EVP_PKEY *pkey); + + EVP_PKEY *X509_get_pubkey(X509 *x); + int X509_set_pubkey(X509 *x, EVP_PKEY *pkey); + int X509_set_version(X509 *x, long version); + int X509_set_serialNumber(X509 *x, ASN1_INTEGER *serial); + + X509_NAME *X509_get_subject_name(const X509 *a); + int X509_set_subject_name(X509 *x, X509_NAME *name); + X509_NAME *X509_get_issuer_name(const X509 *a); + int X509_set_issuer_name(X509 *x, X509_NAME *name); + + int X509_pubkey_digest(const X509 *data, const EVP_MD *type, + unsigned char *md, unsigned int *len); + int X509_digest(const X509 *data, const EVP_MD *type, + unsigned char *md, unsigned int *len); + + const char *X509_verify_cert_error_string(long n); + int X509_verify_cert(X509_STORE_CTX *ctx); + + int X509_get_signature_nid(const X509 *x); + + unsigned char *X509_alias_get0(X509 *x, int *len); + unsigned char *X509_keyid_get0(X509 *x, int *len); + int X509_check_private_key(X509 *x, EVP_PKEY *k); +]] + +if OPENSSL_11_OR_LATER then + ffi.cdef [[ + int X509_up_ref(X509 *a); + + int X509_set1_notBefore(X509 *x, const ASN1_TIME *tm); + int X509_set1_notAfter(X509 *x, const ASN1_TIME *tm); + /*const*/ ASN1_TIME *X509_get0_notBefore(const X509 *x); + /*const*/ ASN1_TIME *X509_get0_notAfter(const X509 *x); + long X509_get_version(const X509 *x); + const ASN1_INTEGER *X509_get0_serialNumber(X509 *x); + + X509_EXTENSION *X509_delete_ext(X509 *x, int loc); + ]] +elseif OPENSSL_10 then + ffi.cdef [[ + // STACK_OF(X509_EXTENSION) + X509_EXTENSION *X509v3_delete_ext(OPENSSL_STACK *x, int loc); + ]] +end + +if OPENSSL_10 or BORINGSSL_110 then + -- in openssl 1.0.x some getters are direct accessor to struct members (defiend by macros) + ffi.cdef [[ + // crypto/x509/x509.h + typedef struct X509_val_st { + ASN1_TIME *notBefore; + ASN1_TIME *notAfter; + } X509_VAL; + + typedef struct X509_algor_st { + ASN1_OBJECT *algorithm; + ASN1_TYPE *parameter; + } X509_ALGOR; + + // Note: this struct is trimmed + typedef struct x509_cinf_st { + /*ASN1_INTEGER*/ void *version; + /*ASN1_INTEGER*/ void *serialNumber; + X509_ALGOR *signature; + X509_NAME *issuer; + X509_VAL *validity; + X509_NAME *subject; + /*X509_PUBKEY*/ void *key; + /*ASN1_BIT_STRING*/ void *issuerUID; /* [ 1 ] optional in v2 */ + /*ASN1_BIT_STRING*/ void *subjectUID; /* [ 2 ] optional in v2 */ + /*STACK_OF(X509_EXTENSION)*/ OPENSSL_STACK *extensions; /* [ 3 ] optional in v3 */ + // trimmed + // ASN1_ENCODING enc; + } X509_CINF; + // Note: this struct is trimmed + struct x509_st { + X509_CINF *cert_info; + // trimmed + } X509; + + int X509_set_notBefore(X509 *x, const ASN1_TIME *tm); + int X509_set_notAfter(X509 *x, const ASN1_TIME *tm); + ASN1_INTEGER *X509_get_serialNumber(X509 *x); + ]] +end + +if BORINGSSL_110 then + ffi.cdef [[ + ASN1_TIME *X509_get_notBefore(const X509 *x); + ASN1_TIME *X509_get_notAfter(const X509 *x); + ]] +end diff --git a/server/resty/openssl/include/x509/name.lua b/server/resty/openssl/include/x509/name.lua new file mode 100644 index 0000000..2f933ae --- /dev/null +++ b/server/resty/openssl/include/x509/name.lua @@ -0,0 +1,21 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.asn1" +require "resty.openssl.include.objects" +local asn1_macro = require "resty.openssl.include.asn1" + +asn1_macro.declare_asn1_functions("X509_NAME") + +ffi.cdef [[ + int X509_NAME_add_entry_by_OBJ(X509_NAME *name, const ASN1_OBJECT *obj, int type, + const unsigned char *bytes, int len, int loc, + int set); + + int X509_NAME_entry_count(const X509_NAME *name); + X509_NAME_ENTRY *X509_NAME_get_entry(X509_NAME *name, int loc); + ASN1_OBJECT *X509_NAME_ENTRY_get_object(const X509_NAME_ENTRY *ne); + ASN1_STRING * X509_NAME_ENTRY_get_data(const X509_NAME_ENTRY *ne); + int X509_NAME_get_index_by_OBJ(X509_NAME *name, const ASN1_OBJECT *obj, + int lastpos); +]]
\ No newline at end of file diff --git a/server/resty/openssl/include/x509/revoked.lua b/server/resty/openssl/include/x509/revoked.lua new file mode 100644 index 0000000..c6539c9 --- /dev/null +++ b/server/resty/openssl/include/x509/revoked.lua @@ -0,0 +1,17 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.asn1" +require "resty.openssl.include.objects" +local asn1_macro = require "resty.openssl.include.asn1" + +asn1_macro.declare_asn1_functions("X509_REVOKED") + +ffi.cdef [[ + int X509_REVOKED_set_serialNumber(X509_REVOKED *x, ASN1_INTEGER *serial); + int X509_REVOKED_set_revocationDate(X509_REVOKED *r, ASN1_TIME *tm); + int X509_REVOKED_add_ext(X509_REVOKED *x, X509_EXTENSION *ex, int loc); + + const ASN1_INTEGER *X509_REVOKED_get0_serialNumber(const X509_REVOKED *r); + const ASN1_TIME *X509_REVOKED_get0_revocationDate(const X509_REVOKED *r); +]]
\ No newline at end of file diff --git a/server/resty/openssl/include/x509_vfy.lua b/server/resty/openssl/include/x509_vfy.lua new file mode 100644 index 0000000..d783d19 --- /dev/null +++ b/server/resty/openssl/include/x509_vfy.lua @@ -0,0 +1,108 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.stack" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL_110 = require("resty.openssl.version").BORINGSSL_110 + +ffi.cdef [[ + X509_STORE *X509_STORE_new(void); + void X509_STORE_free(X509_STORE *v); + /* int X509_STORE_lock(X509_STORE *ctx); + int X509_STORE_unlock(X509_STORE *ctx); + int X509_STORE_up_ref(X509_STORE *v); + // STACK_OF(X509_OBJECT) + OPENSSL_STACK *X509_STORE_get0_objects(X509_STORE *v);*/ + + int X509_STORE_add_cert(X509_STORE *ctx, X509 *x); + int X509_STORE_add_crl(X509_STORE *ctx, X509_CRL *x); + int X509_STORE_load_locations(X509_STORE *ctx, + const char *file, const char *dir); + int X509_STORE_set_default_paths(X509_STORE *ctx); + int X509_STORE_set_flags(X509_STORE *ctx, unsigned long flags); + int X509_STORE_set_depth(X509_STORE *store, int depth); + int X509_STORE_set_purpose(X509_STORE *ctx, int purpose); + + X509_STORE_CTX *X509_STORE_CTX_new(void); + void X509_STORE_CTX_free(X509_STORE_CTX *ctx); + // STACK_OF(X509) + int X509_STORE_CTX_init(X509_STORE_CTX *ctx, X509_STORE *store, + X509 *x509, OPENSSL_STACK *chain); + + int X509_STORE_CTX_get_error(X509_STORE_CTX *ctx); + + int X509_STORE_CTX_set_default(X509_STORE_CTX *ctx, const char *name); + + int X509_PURPOSE_get_by_sname(char *sname); + X509_PURPOSE *X509_PURPOSE_get0(int idx); + int X509_PURPOSE_get_id(const X509_PURPOSE *xp); +]] + +local _M = { + verify_flags = { + X509_V_FLAG_CB_ISSUER_CHECK = 0x0, -- Deprecated + X509_V_FLAG_USE_CHECK_TIME = 0x2, + X509_V_FLAG_CRL_CHECK = 0x4, + X509_V_FLAG_CRL_CHECK_ALL = 0x8, + X509_V_FLAG_IGNORE_CRITICAL = 0x10, + X509_V_FLAG_X509_STRICT = 0x20, + X509_V_FLAG_ALLOW_PROXY_CERTS = 0x40, + X509_V_FLAG_POLICY_CHECK = 0x80, + X509_V_FLAG_EXPLICIT_POLICY = 0x100, + X509_V_FLAG_INHIBIT_ANY = 0x200, + X509_V_FLAG_INHIBIT_MAP = 0x400, + X509_V_FLAG_NOTIFY_POLICY = 0x800, + X509_V_FLAG_EXTENDED_CRL_SUPPORT = 0x1000, + X509_V_FLAG_USE_DELTAS = 0x2000, + X509_V_FLAG_CHECK_SS_SIGNATURE = 0x4000, + X509_V_FLAG_TRUSTED_FIRST = 0x8000, + X509_V_FLAG_SUITEB_128_LOS_ONLY = 0x10000, + X509_V_FLAG_SUITEB_192_LOS = 0x20000, + X509_V_FLAG_SUITEB_128_LOS = 0x30000, + X509_V_FLAG_PARTIAL_CHAIN = 0x80000, + X509_V_FLAG_NO_ALT_CHAINS = 0x100000, + X509_V_FLAG_NO_CHECK_TIME = 0x200000, + }, +} + +if OPENSSL_10 or BORINGSSL_110 then + ffi.cdef [[ + // STACK_OF(X509) + OPENSSL_STACK *X509_STORE_CTX_get_chain(X509_STORE_CTX *ctx); + ]]; + _M.X509_STORE_CTX_get0_chain = C.X509_STORE_CTX_get_chain +elseif OPENSSL_11_OR_LATER then + ffi.cdef [[ + // STACK_OF(X509) + OPENSSL_STACK *X509_STORE_CTX_get0_chain(X509_STORE_CTX *ctx); + ]]; + _M.X509_STORE_CTX_get0_chain = C.X509_STORE_CTX_get0_chain +end + +if OPENSSL_3X then + ffi.cdef [[ + X509_STORE_CTX *X509_STORE_CTX_new_ex(OSSL_LIB_CTX *libctx, const char *propq); + + int X509_STORE_set_default_paths_ex(X509_STORE *ctx, OSSL_LIB_CTX *libctx, + const char *propq); + /* int X509_STORE_load_file_ex(X509_STORE *ctx, const char *file, + OSSL_LIB_CTX *libctx, const char *propq); + int X509_STORE_load_store_ex(X509_STORE *ctx, const char *uri, + OSSL_LIB_CTX *libctx, const char *propq); */ + int X509_STORE_load_locations_ex(X509_STORE *ctx, const char *file, + const char *dir, OSSL_LIB_CTX *libctx, + const char *propq); + ]] + _M.X509_STORE_set_default_paths = function(...) return C.X509_STORE_set_default_paths_ex(...) end + _M.X509_STORE_load_locations = function(...) return C.X509_STORE_load_locations_ex(...) end +else + _M.X509_STORE_set_default_paths = function(s) return C.X509_STORE_set_default_paths(s) end + _M.X509_STORE_load_locations = function(s, file, dir) return C.X509_STORE_load_locations(s, file, dir) end +end + + +return _M + diff --git a/server/resty/openssl/include/x509v3.lua b/server/resty/openssl/include/x509v3.lua new file mode 100644 index 0000000..6882c6e --- /dev/null +++ b/server/resty/openssl/include/x509v3.lua @@ -0,0 +1,108 @@ +local ffi = require "ffi" + +require "resty.openssl.include.ossl_typ" +require "resty.openssl.include.stack" +local asn1_macro = require "resty.openssl.include.asn1" + +ffi.cdef [[ + // STACK_OF(OPENSSL_STRING) + OPENSSL_STACK *X509_get1_ocsp(X509 *x); + void X509_email_free(OPENSSL_STACK *sk); + void X509V3_set_nconf(X509V3_CTX *ctx, CONF *conf); + + typedef struct EDIPartyName_st EDIPARTYNAME; + + typedef struct otherName_st OTHERNAME; + + typedef struct GENERAL_NAME_st { + int type; + union { + char *ptr; + OTHERNAME *otherName; /* otherName */ + ASN1_IA5STRING *rfc822Name; + ASN1_IA5STRING *dNSName; + ASN1_TYPE *x400Address; + X509_NAME *directoryName; + EDIPARTYNAME *ediPartyName; + ASN1_IA5STRING *uniformResourceIdentifier; + ASN1_OCTET_STRING *iPAddress; + ASN1_OBJECT *registeredID; + /* Old names */ + ASN1_OCTET_STRING *ip; /* iPAddress */ + X509_NAME *dirn; /* dirn */ + ASN1_IA5STRING *ia5; /* rfc822Name, dNSName, + * uniformResourceIdentifier */ + ASN1_OBJECT *rid; /* registeredID */ + ASN1_TYPE *other; /* x400Address */ + } d; + } GENERAL_NAME; + + // STACK_OF(GENERAL_NAME) + typedef struct stack_st GENERAL_NAMES; + + // STACK_OF(X509_EXTENSION) + int X509V3_add1_i2d(OPENSSL_STACK **x, int nid, void *value, + int crit, unsigned long flags); + void *X509V3_EXT_d2i(X509_EXTENSION *ext); + X509_EXTENSION *X509V3_EXT_i2d(int ext_nid, int crit, void *ext_struc); + int X509V3_EXT_print(BIO *out, X509_EXTENSION *ext, unsigned long flag, + int indent); + + int X509_add1_ext_i2d(X509 *x, int nid, void *value, int crit, + unsigned long flags); + // although the struct has plural form, it's not a stack + typedef struct BASIC_CONSTRAINTS_st { + int ca; + ASN1_INTEGER *pathlen; + } BASIC_CONSTRAINTS; + + void X509V3_set_ctx(X509V3_CTX *ctx, X509 *issuer, X509 *subject, + X509_REQ *req, X509_CRL *crl, int flags); + + X509_EXTENSION *X509V3_EXT_nconf_nid(CONF *conf, X509V3_CTX *ctx, int ext_nid, + const char *value); + X509_EXTENSION *X509V3_EXT_nconf(CONF *conf, X509V3_CTX *ctx, const char *name, + const char *value); + int X509V3_EXT_print(BIO *out, X509_EXTENSION *ext, unsigned long flag, + int indent); + + void *X509V3_get_d2i(const OPENSSL_STACK *x, int nid, int *crit, int *idx); + + int X509v3_get_ext_by_NID(const OPENSSL_STACK *x, + int nid, int lastpos); + + X509_EXTENSION *X509v3_get_ext(const OPENSSL_STACK *x, int loc); + + // STACK_OF(ACCESS_DESCRIPTION) + typedef struct stack_st AUTHORITY_INFO_ACCESS; + + typedef struct ACCESS_DESCRIPTION_st { + ASN1_OBJECT *method; + GENERAL_NAME *location; + } ACCESS_DESCRIPTION; + + typedef struct DIST_POINT_NAME_st { + int type; + union { + GENERAL_NAMES *fullname; + // STACK_OF(X509_NAME_ENTRY) + OPENSSL_STACK *relativename; + } name; + /* If relativename then this contains the full distribution point name */ + X509_NAME *dpname; + } DIST_POINT_NAME; + + typedef struct DIST_POINT_st { + DIST_POINT_NAME *distpoint; + ASN1_BIT_STRING *reasons; + GENERAL_NAMES *CRLissuer; + int dp_reasons; + } DIST_POINT; + +]] + +asn1_macro.declare_asn1_functions("GENERAL_NAME") +asn1_macro.declare_asn1_functions("BASIC_CONSTRAINTS") +asn1_macro.declare_asn1_functions("AUTHORITY_INFO_ACCESS") -- OCSP responder and CA +asn1_macro.declare_asn1_functions("ACCESS_DESCRIPTION") +asn1_macro.declare_asn1_functions("DIST_POINT") -- CRL distribution points diff --git a/server/resty/openssl/kdf.lua b/server/resty/openssl/kdf.lua new file mode 100644 index 0000000..62188bc --- /dev/null +++ b/server/resty/openssl/kdf.lua @@ -0,0 +1,388 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require("resty.openssl.objects") +require("resty.openssl.include.evp.md") +-- used by legacy EVP_PKEY_derive interface +require("resty.openssl.include.evp.pkey") +local kdf_macro = require "resty.openssl.include.evp.kdf" +local ctx_lib = require "resty.openssl.ctx" +local format_error = require("resty.openssl.err").format_error +local version_num = require("resty.openssl.version").version_num +local version_text = require("resty.openssl.version").version_text +local BORINGSSL = require("resty.openssl.version").BORINGSSL +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local ctypes = require "resty.openssl.auxiliary.ctypes" + +--[[ +https://wiki.openssl.org/index.php/EVP_Key_Derivation + +OpenSSL 1.0.2 and above provides PBKDF2 by way of PKCS5_PBKDF2_HMAC and PKCS5_PBKDF2_HMAC_SHA1. +OpenSSL 1.1.0 and above additionally provides HKDF and TLS1 PRF KDF by way of EVP_PKEY_derive and Scrypt by way of EVP_PBE_scrypt +OpenSSL 1.1.1 and above additionally provides Scrypt by way of EVP_PKEY_derive. +OpenSSL 3.0 additionally provides Single Step KDF, SSH KDF, PBKDF2, Scrypt, HKDF, ANSI X9.42 KDF, ANSI X9.63 KDF and TLS1 PRF KDF by way of EVP_KDF. +From OpenSSL 3.0 the recommended way of performing key derivation is to use the EVP_KDF functions. If compatibility with OpenSSL 1.1.1 is required then a limited set of KDFs can be used via EVP_PKEY_derive. +]] + +local NID_id_pbkdf2 = -1 +local NID_id_scrypt = -2 +local NID_tls1_prf = -3 +local NID_hkdf = -4 +if version_num >= 0x10002000 then + NID_id_pbkdf2 = C.OBJ_txt2nid("PBKDF2") + assert(NID_id_pbkdf2 > 0) +end +if version_num >= 0x10100000 and not BORINGSSL then + NID_hkdf = C.OBJ_txt2nid("HKDF") + assert(NID_hkdf > 0) + NID_tls1_prf = C.OBJ_txt2nid("TLS1-PRF") + assert(NID_tls1_prf > 0) + -- we use EVP_PBE_scrypt to do scrypt, so this is supported >= 1.1.0 + NID_id_scrypt = C.OBJ_txt2nid("id-scrypt") + assert(NID_id_scrypt > 0) +end + +local _M = { + HKDEF_MODE_EXTRACT_AND_EXPAND = kdf_macro.EVP_PKEY_HKDEF_MODE_EXTRACT_AND_EXPAND, + HKDEF_MODE_EXTRACT_ONLY = kdf_macro.EVP_PKEY_HKDEF_MODE_EXTRACT_ONLY, + HKDEF_MODE_EXPAND_ONLY = kdf_macro.EVP_PKEY_HKDEF_MODE_EXPAND_ONLY, + + PBKDF2 = NID_id_pbkdf2, + SCRYPT = NID_id_scrypt, + TLS1_PRF = NID_tls1_prf, + HKDF = NID_hkdf, +} + +local type_literals = { + [NID_id_pbkdf2] = "PBKDF2", + [NID_id_scrypt] = "scrypt", + [NID_tls1_prf] = "TLS-1PRF", + [NID_hkdf] = "HKDF", +} + +local TYPE_NUMBER = 0x1 +local TYPE_STRING = 0x2 + +local function check_options(opt, nid, field, typ, is_optional, required_only_if_nid) + local v = opt[field] + if not v then + if is_optional or (required_only_if_nid and required_only_if_nid ~= nid) then + return typ == TYPE_NUMBER and 0 or nil + else + return nil, "\"" .. field .. "\" must be set" + end + end + + if typ == TYPE_NUMBER then + v = tonumber(v) + if not typ then + return nil, "except a number as \"" .. field .. "\"" + end + elseif typ == TYPE_STRING then + if type(v) ~= "string" then + return nil, "except a string as \"" .. field .. "\"" + end + else + error("don't known how to check " .. typ, 2) + end + + return v +end + +local function check_hkdf_options(opt) + local mode = opt.hkdf_mode + if not mode or version_num < 0x10101000 then + mode = _M.HKDEF_MODE_EXTRACT_AND_EXPAND + end + + if mode == _M.HKDEF_MODE_EXTRACT_AND_EXPAND and ( + not opt.salt or not opt.hkdf_info) then + return '""salt" and "hkdf_info" are required for EXTRACT_AND_EXPAND mode' + elseif mode == _M.HKDEF_MODE_EXTRACT_ONLY and not opt.salt then + return '"salt" is required for EXTRACT_ONLY mode' + elseif mode == _M.EVP_PKEY_HKDEF_MODE_EXPAND_ONLY and not opt.hkdf_info then + return '"hkdf_info" is required for EXPAND_ONLY mode' + end + + return nil +end + +local options_schema = { + outlen = { TYPE_NUMBER }, + pass = { TYPE_STRING, true }, + salt = { TYPE_STRING, true }, + md = { TYPE_STRING, true }, + -- pbkdf2 only + pbkdf2_iter = { TYPE_NUMBER, true }, + -- hkdf only + hkdf_key = { TYPE_STRING, nil, NID_hkdf }, + hkdf_mode = { TYPE_NUMBER, true }, + hkdf_info = { TYPE_STRING, true }, + -- tls1-prf + tls1_prf_secret = { TYPE_STRING, nil, NID_tls1_prf }, + tls1_prf_seed = { TYPE_STRING, nil, NID_tls1_prf }, + -- scrypt only + scrypt_maxmem = { TYPE_NUMBER, true }, + scrypt_N = { TYPE_NUMBER, nil, NID_id_scrypt }, + scrypt_r = { TYPE_NUMBER, nil, NID_id_scrypt }, + scrypt_p = { TYPE_NUMBER, nil, NID_id_scrypt }, +} + +local outlen = ctypes.ptr_of_uint64() + +function _M.derive(options) + local typ = options.type + if not typ then + return nil, "kdf.derive: \"type\" must be set" + elseif type(typ) ~= "number" then + return nil, "kdf.derive: expect a number as \"type\"" + end + + if typ <= 0 then + return nil, "kdf.derive: kdf type " .. (type_literals[typ] or tostring(typ)) .. + " not supported in " .. version_text + end + + for k, v in pairs(options_schema) do + local v, err = check_options(options, typ, k, unpack(v)) + if err then + return nil, "kdf.derive: " .. err + end + options[k] = v + end + + if typ == NID_hkdf then + local err = check_hkdf_options(options) + if err then + return nil, "kdf.derive: " .. err + end + end + + local salt_len = 0 + if options.salt then + salt_len = #options.salt + end + local pass_len = 0 + if options.pass then + pass_len = #options.pass + end + + local md + if OPENSSL_3X then + md = C.EVP_MD_fetch(ctx_lib.get_libctx(), options.md or 'sha1', options.properties) + else + md = C.EVP_get_digestbyname(options.md or 'sha1') + end + if md == nil then + return nil, string.format("kdf.derive: invalid digest type \"%s\"", md) + end + + local buf = ctypes.uchar_array(options.outlen) + + -- begin legacay low level routines + local code + if typ == NID_id_pbkdf2 then + -- make openssl 1.0.2 happy + if version_num < 0x10100000 and not options.pass then + options.pass = "" + pass_len = 0 + end + -- https://www.openssl.org/docs/man1.1.0/man3/PKCS5_PBKDF2_HMAC.html + local iter = options.pbkdf2_iter + if iter < 1 then + iter = 1 + end + code = C.PKCS5_PBKDF2_HMAC( + options.pass, pass_len, + options.salt, salt_len, iter, + md, options.outlen, buf + ) + elseif typ == NID_id_scrypt then + code = C.EVP_PBE_scrypt( + options.pass, pass_len, + options.salt, salt_len, + options.scrypt_N, options.scrypt_r, options.scrypt_p, options.scrypt_maxmem, + buf, options.outlen + ) + elseif typ ~= NID_tls1_prf and typ ~= NID_hkdf then + return nil, string.format("kdf.derive: unknown type %d", typ) + end + if code then + if code ~= 1 then + return nil, format_error("kdf.derive") + else + return ffi_str(buf, options.outlen) + end + end + -- end legacay low level routines + + -- begin EVP_PKEY_derive routines + outlen[0] = options.outlen + + local ctx = C.EVP_PKEY_CTX_new_id(typ, nil) + if ctx == nil then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_new_id") + end + ffi_gc(ctx, C.EVP_PKEY_CTX_free) + if C.EVP_PKEY_derive_init(ctx) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_derive_init") + end + + if typ == NID_tls1_prf then + if kdf_macro.EVP_PKEY_CTX_set_tls1_prf_md(ctx, md) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set_tls1_prf_md") + end + if kdf_macro.EVP_PKEY_CTX_set1_tls1_prf_secret(ctx, options.tls1_prf_secret) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set1_tls1_prf_secret") + end + if kdf_macro.EVP_PKEY_CTX_add1_tls1_prf_seed(ctx, options.tls1_prf_seed) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_add1_tls1_prf_seed") + end + elseif typ == NID_hkdf then + if kdf_macro.EVP_PKEY_CTX_set_hkdf_md(ctx, md) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set_hkdf_md") + end + if options.salt and + kdf_macro.EVP_PKEY_CTX_set1_hkdf_salt(ctx, options.salt) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set1_hkdf_salt") + end + if options.hkdf_key and + kdf_macro.EVP_PKEY_CTX_set1_hkdf_key(ctx, options.hkdf_key) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set1_hkdf_key") + end + if options.hkdf_info and + kdf_macro.EVP_PKEY_CTX_add1_hkdf_info(ctx, options.hkdf_info) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_add1_hkdf_info") + end + if options.hkdf_mode then + if version_num >= 0x10101000 then + if kdf_macro.EVP_PKEY_CTX_set_hkdf_mode(ctx, options.hkdf_mode) ~= 1 then + return nil, format_error("kdf.derive: EVP_PKEY_CTX_set_hkdf_mode") + end + if options.hkdf_mode == _M.HKDEF_MODE_EXTRACT_ONLY then + local md_size = OPENSSL_3X and C.EVP_MD_get_size(md) or C.EVP_MD_size(md) + if options.outlen ~= md_size then + options.outlen = md_size + ngx.log(ngx.WARN, "hkdf_mode EXTRACT_ONLY outputs fixed length of ", md_size, + " key, ignoring options.outlen") + end + outlen[0] = md_size + buf = ctypes.uchar_array(md_size) + end + else + ngx.log(ngx.WARN, "hkdf_mode is not effective in ", version_text) + end + end + else + return nil, string.format("kdf.derive: unknown type %d", typ) + end + code = C.EVP_PKEY_derive(ctx, buf, outlen) + if code == -2 then + return nil, "kdf.derive: operation is not supported by the public key algorithm" + end + -- end EVP_PKEY_derive routines + + return ffi_str(buf, options.outlen) +end + +if not OPENSSL_3X then + return _M +end + +_M.derive_legacy = _M.derive +_M.derive = nil + +-- OPENSSL 3.0 style API +local param_lib = require "resty.openssl.param" +local SIZE_MAX = ctypes.SIZE_MAX + +local mt = {__index = _M} + +local kdf_ctx_ptr_ct = ffi.typeof('EVP_KDF_CTX*') + +function _M.new(typ, properties) + local algo = C.EVP_KDF_fetch(ctx_lib.get_libctx(), typ, properties) + if algo == nil then + return nil, format_error(string.format("mac.new: invalid mac type \"%s\"", typ)) + end + + local ctx = C.EVP_KDF_CTX_new(algo) + if ctx == nil then + return nil, "mac.new: failed to create EVP_MAC_CTX" + end + ffi_gc(ctx, C.EVP_KDF_CTX_free) + + local buf + local buf_size = tonumber(C.EVP_KDF_CTX_get_kdf_size(ctx)) + if buf_size == SIZE_MAX then -- no fixed size + buf_size = nil + else + buf = ctypes.uchar_array(buf_size) + end + + return setmetatable({ + ctx = ctx, + algo = algo, + buf = buf, + buf_size = buf_size, + }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(kdf_ctx_ptr_ct, l.ctx) +end + +function _M:get_provider_name() + local p = C.EVP_KDF_get0_provider(self.algo) + if p == nil then + return nil + end + return ffi_str(C.OSSL_PROVIDER_get0_name(p)) +end + +_M.settable_params, _M.set_params, _M.gettable_params, _M.get_param = param_lib.get_params_func("EVP_KDF_CTX") + +function _M:derive(outlen, options, options_count) + if not _M.istype(self) then + return _M.derive_legacy(self) + end + + if self.buf_size and outlen then + return nil, string.format("kdf:derive: this KDF has fixed output size %d, ".. + "it can't be set manually", self.buf_size) + end + + outlen = self.buf_size or outlen + local buf = self.buf or ctypes.uchar_array(outlen) + + if options_count then + options_count = options_count - 1 + else + options_count = 0 + for k, v in pairs(options) do options_count = options_count + 1 end + end + + local param, err + if options_count > 0 then + local schema = self:settable_params(true) -- raw schema + param, err = param_lib.construct(options, nil, schema) + if err then + return nil, "kdf:derive: " .. err + end + end + + if C.EVP_KDF_derive(self.ctx, buf, outlen, param) ~= 1 then + return nil, format_error("kdf:derive") + end + + return ffi_str(buf, outlen) +end + +function _M:reset() + C.EVP_KDF_CTX_reset(self.ctx) + return true +end + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/mac.lua b/server/resty/openssl/mac.lua new file mode 100644 index 0000000..65f5e38 --- /dev/null +++ b/server/resty/openssl/mac.lua @@ -0,0 +1,96 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require "resty.openssl.include.evp.mac" +local param_lib = require "resty.openssl.param" +local ctx_lib = require "resty.openssl.ctx" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local _M = {} +local mt = {__index = _M} + +local mac_ctx_ptr_ct = ffi.typeof('EVP_MAC_CTX*') +local param_types = { + cipher = param_lib.OSSL_PARAM_UTF8_STRING, + digest = param_lib.OSSL_PARAM_UTF8_STRING, +} +local params = {} + +function _M.new(key, typ, cipher, digest, properties) + if not OPENSSL_3X then + return false, "EVP_MAC is only supported from OpenSSL 3.0" + end + + local algo = C.EVP_MAC_fetch(ctx_lib.get_libctx(), typ, properties) + if algo == nil then + return nil, format_error(string.format("mac.new: invalid mac type \"%s\"", typ)) + end + + local ctx = C.EVP_MAC_CTX_new(algo) + if ctx == nil then + return nil, "mac.new: failed to create EVP_MAC_CTX" + end + ffi_gc(ctx, C.EVP_MAC_CTX_free) + + params.digest = digest + params.cipher = cipher + local p = param_lib.construct(params, 2, param_types) + + local code = C.EVP_MAC_init(ctx, key, #key, p) + if code ~= 1 then + return nil, format_error(string.format("mac.new: invalid cipher or digest type")) + end + + local md_size = C.EVP_MAC_CTX_get_mac_size(ctx) + + return setmetatable({ + ctx = ctx, + algo = algo, + buf = ctypes.uchar_array(md_size), + buf_size = md_size, + }, mt), nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(mac_ctx_ptr_ct, l.ctx) +end + +function _M:get_provider_name() + local p = C.EVP_MAC_get0_provider(self.algo) + if p == nil then + return nil + end + return ffi_str(C.OSSL_PROVIDER_get0_name(p)) +end + +_M.settable_params, _M.set_params, _M.gettable_params, _M.get_param = param_lib.get_params_func("EVP_MAC_CTX") + +function _M:update(...) + for _, s in ipairs({...}) do + if C.EVP_MAC_update(self.ctx, s, #s) ~= 1 then + return false, format_error("digest:update") + end + end + return true, nil +end + +function _M:final(s) + if s then + local _, err = self:update(s) + if err then + return nil, err + end + end + + local length = ctypes.ptr_of_size_t() + if C.EVP_MAC_final(self.ctx, self.buf, length, self.buf_size) ~= 1 then + return nil, format_error("digest:final: EVP_MAC_final") + end + return ffi_str(self.buf, length[0]) +end + +return _M diff --git a/server/resty/openssl/objects.lua b/server/resty/openssl/objects.lua new file mode 100644 index 0000000..bd02a38 --- /dev/null +++ b/server/resty/openssl/objects.lua @@ -0,0 +1,74 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string +local ffi_sizeof = ffi.sizeof + +require "resty.openssl.include.objects" +require "resty.openssl.include.err" + +local buf = ffi.new('char[?]', 100) + +local function obj2table(obj) + local nid = C.OBJ_obj2nid(obj) + + local len = C.OBJ_obj2txt(buf, ffi_sizeof(buf), obj, 1) + local oid = ffi_str(buf, len) + + return { + id = oid, + nid = nid, + sn = ffi_str(C.OBJ_nid2sn(nid)), + ln = ffi_str(C.OBJ_nid2ln(nid)), + } +end + +local function nid2table(nid) + return obj2table(C.OBJ_nid2obj(nid)) +end + +local function txt2nid(txt) + if type(txt) ~= "string" then + return nil, "objects.txt2nid: expect a string at #1" + end + local nid = C.OBJ_txt2nid(txt) + if nid == 0 then + -- clean up error occurs during OBJ_txt2nid + C.ERR_clear_error() + return nil, "objects.txt2nid: invalid NID text " .. txt + end + return nid +end + +local function txtnid2nid(txt_nid) + local nid + if type(txt_nid) == "string" then + nid = C.OBJ_txt2nid(txt_nid) + if nid == 0 then + -- clean up error occurs during OBJ_txt2nid + C.ERR_clear_error() + return nil, "objects.txtnid2nid: invalid NID text " .. txt_nid + end + elseif type(txt_nid) == "number" then + nid = txt_nid + else + return nil, "objects.txtnid2nid: expect string or number at #1" + end + return nid +end + +local function find_sigid_algs(nid) + local out = ffi.new("int[0]") + if C.OBJ_find_sigid_algs(nid, out, nil) == 0 then + return 0, "objects.find_sigid_algs: invalid sigid " .. nid + end + return tonumber(out[0]) +end + +return { + obj2table = obj2table, + nid2table = nid2table, + txt2nid = txt2nid, + txtnid2nid = txtnid2nid, + find_sigid_algs = find_sigid_algs, + create = C.OBJ_create, +}
\ No newline at end of file diff --git a/server/resty/openssl/param.lua b/server/resty/openssl/param.lua new file mode 100644 index 0000000..2c8dcea --- /dev/null +++ b/server/resty/openssl/param.lua @@ -0,0 +1,322 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_new = ffi.new +local ffi_str = ffi.string +local ffi_cast = ffi.cast + +require "resty.openssl.include.param" +local format_error = require("resty.openssl.err").format_error +local bn_lib = require("resty.openssl.bn") +local null = require("resty.openssl.auxiliary.ctypes").null + +local OSSL_PARAM_INTEGER = 1 +local OSSL_PARAM_UNSIGNED_INTEGER = 2 +local OSSL_PARAM_REAL = 3 +local OSSL_PARAM_UTF8_STRING = 4 +local OSSL_PARAM_OCTET_STRING = 5 +local OSSL_PARAM_UTF8_PTR = 6 +local OSSL_PARAM_OCTET_PTR = 7 + +local alter_type_key = {} +local buf_param_key = {} + +local function construct(buf_t, length, types_map, types_size) + if not length then + length = 0 + for k, v in pairs(buf_t) do length = length + 1 end + end + + local params = ffi_new("OSSL_PARAM[?]", length + 1) + + local i = 0 + local buf_param + for key, value in pairs(buf_t) do + local typ = types_map[key] + if not typ then + return nil, "param:construct: unknown key \"" .. key .. "\"" + end + local param, buf, size + if value == null then -- out + value = nil + size = types_size and types_size[key] or 100 + if typ == OSSL_PARAM_UTF8_STRING or typ == OSSL_PARAM_OCTET_STRING then + buf = ffi_new("char[?]", size) + end + else + local numeric = type(value) == "number" + if (numeric and typ >= OSSL_PARAM_UTF8_STRING) or + (not numeric and typ <= OSSL_PARAM_UNSIGNED_INTEGER) then + local alter_typ = types_map[alter_type_key] and types_map[alter_type_key][key] + if alter_typ and ((numeric and alter_typ <= OSSL_PARAM_UNSIGNED_INTEGER) or + (not numeric and alter_typ >= OSSL_PARAM_UTF8_STRING)) then + typ = alter_typ + else + return nil, "param:construct: key \"" .. key .. "\" can't be a " .. type(value) + end + end + end + + if typ == "bn" then -- out only + buf = ffi_new("char[?]", size) + param = C.OSSL_PARAM_construct_BN(key, buf, size) + buf_param = buf_param or {} + buf_param[key] = param + elseif typ == OSSL_PARAM_INTEGER then + buf = value and ffi_new("int[1]", value) or ffi_new("int[1]") + param = C.OSSL_PARAM_construct_int(key, buf) + elseif typ == OSSL_PARAM_UNSIGNED_INTEGER then + buf = value and ffi_new("unsigned int[1]", value) or + ffi_new("unsigned int[1]") + param = C.OSSL_PARAM_construct_uint(key, buf) + elseif typ == OSSL_PARAM_UTF8_STRING then + buf = value and ffi_cast("char *", value) or buf + param = C.OSSL_PARAM_construct_utf8_string(key, buf, value and #value or size) + elseif typ == OSSL_PARAM_OCTET_STRING then + buf = value and ffi_cast("char *", value) or buf + param = C.OSSL_PARAM_construct_octet_string(key, ffi_cast("void*", buf), + value and #value or size) + elseif typ == OSSL_PARAM_UTF8_PTR then + buf = ffi_new("char*[1]") + param = C.OSSL_PARAM_construct_utf8_ptr(key, buf, 0) + elseif typ == OSSL_PARAM_OCTET_PTR then + buf = ffi_new("char*[1]") + param = C.OSSL_PARAM_construct_octet_ptr(key, ffi_cast("void**", buf), 0) + else + error("type " .. typ .. " is not yet implemented") + end + if not value then -- out + buf_t[key] = buf + end + params[i] = param + i = i + 1 + end + + buf_t[buf_param_key] = buf_param + params[length] = C.OSSL_PARAM_construct_end() + + return params +end + +local function parse(buf_t, length, types_map, types_size) + for key, buf in pairs(buf_t) do + local typ = types_map[key] + local sz = types_size and types_size[key] + + if key == buf_param_key then -- luacheck: ignore + -- ignore + elseif buf == nil or buf[0] == nil then + buf_t[key] = nil + elseif typ == "bn" then + local bn_t = ffi_new("BIGNUM*[1]") + local param = buf_t[buf_param_key][key] + if C.OSSL_PARAM_get_BN(param, bn_t) ~= 1 then + return nil, format_error("param:parse: OSSL_PARAM_get_BN") + end + buf_t[key] = bn_lib.dup(bn_t[0]) + elseif typ == OSSL_PARAM_INTEGER or + typ == OSSL_PARAM_UNSIGNED_INTEGER then + buf_t[key] = tonumber(buf[0]) + elseif typ == OSSL_PARAM_UTF8_STRING or + typ == OSSL_PARAM_OCTET_STRING then + buf_t[key] = sz and ffi_str(buf, sz) or ffi_str(buf) + elseif typ == OSSL_PARAM_UTF8_PTR or + typ == OSSL_PARAM_OCTET_PTR then + buf_t[key] = sz and ffi_str(buf[0], sz) or ffi_str(buf[0]) + elseif not typ then + return nil, "param:parse: unknown key type \"" .. key .. "\"" + else + error("type " .. typ .. " is not yet implemented") + end + end + -- for GC + buf_t[buf_param_key] = nil + + return buf_t +end + +local param_type_readable = { + [OSSL_PARAM_UNSIGNED_INTEGER] = "unsigned integer", + [OSSL_PARAM_INTEGER] = "integer", + [OSSL_PARAM_REAL] = "real number", + [OSSL_PARAM_UTF8_PTR] = "pointer to a UTF8 encoded string", + [OSSL_PARAM_UTF8_STRING] = "UTF8 encoded string", + [OSSL_PARAM_OCTET_PTR] = "pointer to an octet string", + [OSSL_PARAM_OCTET_STRING] = "octet string", +} + +local function readable_data_type(p) + local typ = p.data_type + local literal = param_type_readable[typ] + if not literal then + literal = string.format("unknown type [%d]", typ) + end + + local sz = tonumber(p.data_size) + if sz == 0 then + literal = literal .. " (arbitrary size)" + else + literal = literal .. string.format(" (max %d bytes large)", sz) + end + return literal +end + +local function parse_params_schema(params, schema, schema_readable) + if params == nil then + return nil, format_error("parse_params_schema") + end + + local i = 0 + while true do + local p = params[i] + if p.key == nil then + break + end + local key = ffi_str(p.key) + if schema then + -- TODO: don't support same key with different types for now + -- prefer string type over integer types + local typ = tonumber(p.data_type) + if schema[key] then + schema[alter_type_key] = schema[alter_type_key] or {} + schema[alter_type_key][key] = typ + else + schema[key] = typ + end + end + -- if schema_return_size then -- only non-ptr string types are needed actually + -- schema_return_size[key] = tonumber(p.return_size) + -- end + if schema_readable then + table.insert(schema_readable, { key, readable_data_type(p) }) + end + i = i + 1 + end + return schema +end + +local param_maps_set, param_maps_get = {}, {} + +local function get_params_func(typ, field) + local typ_lower = typ:sub(5):lower() + if typ_lower:sub(-4) == "_ctx" then + typ_lower = typ_lower:sub(0, -5) + end + -- field name for indexing schema, usually the (const) one created by + -- EVP_TYP_fetch or EVP_get_typebynam,e + field = field or "algo" + + local cf_settable = C[typ .. "_settable_params"] + local settable = function(self, raw) + local k = self[field] + if raw and param_maps_set[k] then + return param_maps_set[k] + end + + local param = cf_settable(self.ctx) + -- no params, this is fine, shouldn't be regarded as an error + if param == nil then + param_maps_set[k] = {} + return {} + end + local schema, schema_reabale = {}, raw and nil or {} + parse_params_schema(param, schema, schema_reabale) + param_maps_set[k] = schema + + return raw and schema or schema_reabale + end + + local cf_set = C[typ .. "_set_params"] + local set = function(self, params) + if not param_maps_set[self[field]] then + local ok, err = self:settable_params() + if not ok then + return false, typ_lower .. ":set_params: " .. err + end + end + + local oparams, err = construct(params, nil, param_maps_set[self[field]]) + if err then + return false, typ_lower .. ":set_params: " .. err + end + + if cf_set(self.ctx, oparams) ~= 1 then + return false, format_error(typ_lower .. ":set_params: " .. typ .. "_set_params") + end + + return true + end + + local cf_gettable = C[typ .. "_gettable_params"] + local gettable = function(self, raw) + local k = self[field] + if raw and param_maps_set[k] then + return param_maps_set[k] + end + + local param = cf_gettable(self.ctx) + -- no params, this is fine, shouldn't be regarded as an error + if param == nil then + param_maps_get[k] = {} + return {} + end + local schema, schema_reabale = {}, raw and nil or {} + parse_params_schema(param, schema, schema_reabale) + param_maps_set[k] = schema + + return raw and schema or schema_reabale + end + + local cf_get = C[typ .. "_get_params"] + local get_buffer, get_size_map = {}, {} + local get = function(self, key, want_size, want_type) + if not param_maps_get[self[field]] then + local ok, err = self:gettable_params() + if not ok then + return false, typ_lower .. ":set_params: " .. err + end + end + local schema = param_maps_set[self[field]] + if schema == nil or not schema[key] then -- nil or null + return nil, typ_lower .. ":get_param: unknown key \"" .. key .. "\"" + end + + table.clear(get_buffer) + table.clear(get_size_map) + get_buffer[key] = null + get_size_map[key] = want_size + schema = want_type and { [key] = want_type } or schema + + local req, err = construct(get_buffer, 1, schema, get_size_map) + if not req then + return nil, typ_lower .. ":get_param: failed to construct params: " .. err + end + + if cf_get(self.ctx, req) ~= 1 then + return nil, format_error(typ_lower .. ":get_param:get") + end + + get_buffer, err = parse(get_buffer, 1, schema, get_size_map) + if err then + return nil, typ_lower .. ":get_param: failed to parse params: " .. err + end + + return get_buffer[key] + end + + return settable, set, gettable, get +end + +return { + OSSL_PARAM_INTEGER = OSSL_PARAM_INTEGER, + OSSL_PARAM_UNSIGNED_INTEGER = OSSL_PARAM_INTEGER, + OSSL_PARAM_REAL = OSSL_PARAM_REAL, + OSSL_PARAM_UTF8_STRING = OSSL_PARAM_UTF8_STRING, + OSSL_PARAM_OCTET_STRING = OSSL_PARAM_OCTET_STRING, + OSSL_PARAM_UTF8_PTR = OSSL_PARAM_UTF8_PTR, + OSSL_PARAM_OCTET_PTR = OSSL_PARAM_OCTET_PTR, + + construct = construct, + parse = parse, + parse_params_schema = parse_params_schema, + get_params_func = get_params_func, +}
\ No newline at end of file diff --git a/server/resty/openssl/pkcs12.lua b/server/resty/openssl/pkcs12.lua new file mode 100644 index 0000000..6e3b216 --- /dev/null +++ b/server/resty/openssl/pkcs12.lua @@ -0,0 +1,168 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require "resty.openssl.include.pkcs12" +require "resty.openssl.include.bio" +local bio_util = require "resty.openssl.auxiliary.bio" +local format_error = require("resty.openssl.err").format_error +local pkey_lib = require "resty.openssl.pkey" +local x509_lib = require "resty.openssl.x509" +local stack_macro = require "resty.openssl.include.stack" +local stack_lib = require "resty.openssl.stack" +local objects_lib = require "resty.openssl.objects" +local ctx_lib = require "resty.openssl.ctx" +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local stack_of_x509_new = stack_lib.new_of("X509") +local stack_of_x509_add = stack_lib.add_of("X509") +local stack_of_x509_iter = stack_lib.mt_of("X509", x509_lib.dup, {}).__ipairs + +local ptr_ptr_of_pkey = ffi.typeof("EVP_PKEY*[1]") +local ptr_ptr_of_x509 = ffi.typeof("X509*[1]") +local ptr_ptr_of_stack = ffi.typeof("OPENSSL_STACK*[1]") + +local function decode(p12, passphrase) + local bio = C.BIO_new_mem_buf(p12, #p12) + if bio == nil then + return nil, "pkcs12.decode: BIO_new_mem_buf() failed" + end + ffi_gc(bio, C.BIO_free) + + local p12 = C.d2i_PKCS12_bio(bio, nil) + if p12 == nil then + return nil, format_error("pkcs12.decode: d2i_PKCS12_bio") + end + ffi_gc(p12, C.PKCS12_free) + + local ppkey = ptr_ptr_of_pkey() + local px509 = ptr_ptr_of_x509() + local pstack = ptr_ptr_of_stack() + local stack = stack_of_x509_new() + -- assign a valid OPENSSL_STACK so gc is taken care of + pstack[0] = stack + + local code = C.PKCS12_parse(p12, passphrase or "", ppkey, px509, pstack) + if code ~= 1 then + return nil, format_error("pkcs12.decode: PKCS12_parse") + end + + local cacerts + local n = stack_macro.OPENSSL_sk_num(stack) + if n > 0 then + cacerts = {} + local iter = stack_of_x509_iter({ ctx = stack }) + for i=1, n do + local _, c = iter() + cacerts[i] = c + end + end + + local friendly_name = C.X509_alias_get0(px509[0], nil) + if friendly_name ~= nil then + friendly_name = ffi_str(friendly_name) + end + + return { + key = pkey_lib.new(ppkey[0]), + cert = x509_lib.new(px509[0]), + friendly_name = friendly_name, + cacerts = cacerts, + -- store reference to the stack, so it's not GC'ed unexpectedly + _stack = stack, + } +end + +local function encode(opts, passphrase, properties) + if passphrase and type(passphrase) ~= "string" then + return nil, "pkcs12.encode: expect passphrase to be a string" + end + local pkey = opts.key + if not pkey_lib.istype(pkey) then + return nil, "pkcs12.encode: expect key to be a pkey instance" + end + local cert = opts.cert + if not x509_lib.istype(cert) then + return nil, "pkcs12.encode: expect cert to be a x509 instance" + end + + local ok, err = cert:check_private_key(pkey) + if not ok then + return nil, "pkcs12.encode: key doesn't match cert: " .. err + end + + local nid_key = opts.nid_key + if nid_key then + nid_key, err = objects_lib.txtnid2nid(nid_key) + if err then + return nil, "pkcs12.encode: invalid nid_key" + end + end + + local nid_cert = opts.nid_cert + if nid_cert then + nid_cert, err = objects_lib.txtnid2nid(nid_cert) + if err then + return nil, "pkcs12.encode: invalid nid_cert" + end + end + + local x509stack + local cacerts = opts.cacerts + if cacerts then + if type(cacerts) ~= "table" then + return nil, "pkcs12.encode: expect cacerts to be a table" + end + if #cacerts > 0 then + -- stack lib handles gc + x509stack = stack_of_x509_new() + for _, c in ipairs(cacerts) do + if not OPENSSL_10 then + if C.X509_up_ref(c.ctx) ~= 1 then + return nil, "pkcs12.encode: failed to add cacerts: X509_up_ref failed" + end + end + local ok, err = stack_of_x509_add(x509stack, c.ctx) + if not ok then + return nil, "pkcs12.encode: failed to add cacerts: " .. err + end + end + if OPENSSL_10 then + -- OpenSSL 1.0.2 doesn't have X509_up_ref + -- shallow copy the stack, up_ref for each element + x509stack = C.X509_chain_up_ref(x509stack) + -- use the shallow gc + ffi_gc(x509stack, stack_macro.OPENSSL_sk_free) + end + end + end + + local p12 + if OPENSSL_3X then + p12 = C.PKCS12_create_ex(passphrase or "", opts.friendly_name, + pkey.ctx, cert.ctx, x509stack, + nid_key or 0, nid_cert or 0, + opts.iter or 0, opts.mac_iter or 0, 0, + ctx_lib.get_libctx(), properties) + else + p12 = C.PKCS12_create(passphrase or "", opts.friendly_name, + pkey.ctx, cert.ctx, x509stack, + nid_key or 0, nid_cert or 0, + opts.iter or 0, opts.mac_iter or 0, 0) + end + if p12 == nil then + return nil, format_error("pkcs12.encode: PKCS12_create") + end + ffi_gc(p12, C.PKCS12_free) + + return bio_util.read_wrap(C.i2d_PKCS12_bio, p12) +end + +return { + decode = decode, + loads = decode, + encode = encode, + dumps = encode, +}
\ No newline at end of file diff --git a/server/resty/openssl/pkey.lua b/server/resty/openssl/pkey.lua new file mode 100644 index 0000000..69c5aae --- /dev/null +++ b/server/resty/openssl/pkey.lua @@ -0,0 +1,942 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_new = ffi.new +local ffi_str = ffi.string +local ffi_cast = ffi.cast +local ffi_copy = ffi.copy + +local rsa_macro = require "resty.openssl.include.rsa" +local dh_macro = require "resty.openssl.include.dh" +require "resty.openssl.include.bio" +require "resty.openssl.include.pem" +require "resty.openssl.include.x509" +require "resty.openssl.include.evp.pkey" +local evp_macro = require "resty.openssl.include.evp" +local pkey_macro = require "resty.openssl.include.evp.pkey" +local bio_util = require "resty.openssl.auxiliary.bio" +local digest_lib = require "resty.openssl.digest" +local rsa_lib = require "resty.openssl.rsa" +local dh_lib = require "resty.openssl.dh" +local ec_lib = require "resty.openssl.ec" +local ecx_lib = require "resty.openssl.ecx" +local objects_lib = require "resty.openssl.objects" +local jwk_lib = require "resty.openssl.auxiliary.jwk" +local ctx_lib = require "resty.openssl.ctx" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error + +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local OPENSSL_111_OR_LATER = require("resty.openssl.version").OPENSSL_111_OR_LATER +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +local ptr_of_uint = ctypes.ptr_of_uint +local ptr_of_size_t = ctypes.ptr_of_size_t +local ptr_of_int = ctypes.ptr_of_int + +local null = ctypes.null +local load_pem_args = { null, null, null } +local load_der_args = { null } + +local get_pkey_key +if OPENSSL_11_OR_LATER then + get_pkey_key = { + [evp_macro.EVP_PKEY_RSA] = function(ctx) return C.EVP_PKEY_get0_RSA(ctx) end, + [evp_macro.EVP_PKEY_EC] = function(ctx) return C.EVP_PKEY_get0_EC_KEY(ctx) end, + [evp_macro.EVP_PKEY_DH] = function(ctx) return C.EVP_PKEY_get0_DH(ctx) end + } +else + get_pkey_key = { + [evp_macro.EVP_PKEY_RSA] = function(ctx) return ctx.pkey and ctx.pkey.rsa end, + [evp_macro.EVP_PKEY_EC] = function(ctx) return ctx.pkey and ctx.pkey.ec end, + [evp_macro.EVP_PKEY_DH] = function(ctx) return ctx.pkey and ctx.pkey.dh end, + } +end + +local load_rsa_key_funcs + +if not OPENSSL_3X then + load_rsa_key_funcs= { + ['PEM_read_bio_RSAPrivateKey'] = true, + ['PEM_read_bio_RSAPublicKey'] = true, + } -- those functions return RSA* instead of EVP_PKEY* +end + +local function load_pem_der(txt, opts, funcs) + local fmt = opts.format or '*' + if fmt ~= 'PEM' and fmt ~= 'DER' and fmt ~= "JWK" and fmt ~= '*' then + return nil, "expecting 'DER', 'PEM', 'JWK' or '*' as \"format\"" + end + + local typ = opts.type or '*' + if typ ~= 'pu' and typ ~= 'pr' and typ ~= '*' then + return nil, "expecting 'pr', 'pu' or '*' as \"type\"" + end + + if fmt == "JWK" and (typ == "pu" or type == "pr") then + return nil, "explictly load private or public key from JWK format is not supported" + end + + ngx.log(ngx.DEBUG, "load key using fmt: ", fmt, ", type: ", typ) + + local bio = C.BIO_new_mem_buf(txt, #txt) + if bio == nil then + return nil, "BIO_new_mem_buf() failed" + end + ffi_gc(bio, C.BIO_free) + + local ctx + + local fs = funcs[fmt][typ] + local passphrase_cb + for f, arg in pairs(fs) do + -- don't need BIO when loading JWK key: we parse it in Lua land + if f == "load_jwk" then + local err + ctx, err = jwk_lib[f](txt) + if ctx == nil then + -- if fmt is explictly set to JWK, we should return an error now + if fmt == "JWK" then + return nil, err + end + ngx.log(ngx.DEBUG, "jwk decode failed: ", err, ", continuing") + end + else + -- #define BIO_CTRL_RESET 1 + local code = C.BIO_ctrl(bio, 1, 0, nil) + if code ~= 1 then + return nil, "BIO_ctrl() failed" + end + + -- only pass in passphrase/passphrase_cb to PEM_* functions + if fmt == "PEM" or (fmt == "*" and arg == load_pem_args) then + if opts.passphrase then + local passphrase = opts.passphrase + if type(passphrase) ~= "string" then + -- clear errors occur when trying + C.ERR_clear_error() + return nil, "passphrase must be a string" + end + arg = { null, nil, passphrase } + elseif opts.passphrase_cb then + passphrase_cb = passphrase_cb or ffi_cast("pem_password_cb", function(buf, size) + local p = opts.passphrase_cb() + local len = #p -- 1 byte for \0 + if len > size then + ngx.log(ngx.WARN, "pkey:load_pem_der: passphrase truncated from ", len, " to ", size) + len = size + end + ffi_copy(buf, p, len) + return len + end) + arg = { null, passphrase_cb, null } + end + end + + ctx = C[f](bio, unpack(arg)) + end + + if ctx ~= nil then + ngx.log(ngx.DEBUG, "pkey:load_pem_der: loaded pkey using function ", f) + + -- pkcs1 functions create a rsa rather than evp_pkey + -- disable the checking in openssl 3.0 for sail safe + if not OPENSSL_3X and load_rsa_key_funcs[f] then + local rsa = ctx + ctx = C.EVP_PKEY_new() + if ctx == null then + return nil, format_error("pkey:load_pem_der: EVP_PKEY_new") + end + + if C.EVP_PKEY_assign(ctx, evp_macro.EVP_PKEY_RSA, rsa) ~= 1 then + C.RSA_free(rsa) + C.EVP_PKEY_free(ctx) + return nil, "pkey:load_pem_der: EVP_PKEY_assign() failed" + end + end + + break + end + end + if passphrase_cb ~= nil then + passphrase_cb:free() + end + + if ctx == nil then + return nil, format_error() + end + -- clear errors occur when trying + C.ERR_clear_error() + return ctx, nil +end + +local function generate_param(key_type, config) + if key_type == evp_macro.EVP_PKEY_DH then + local dh_group = config.group + if dh_group then + local get_group_func = dh_macro.dh_groups[dh_group] + if not get_group_func then + return nil, "unknown pre-defined group " .. dh_group + end + local ctx = get_group_func() + if ctx == nil then + return nil, format_error("DH_get_x") + end + local params = C.EVP_PKEY_new() + if not params then + return nil, format_error("EVP_PKEY_new") + end + ffi_gc(params, C.EVP_PKEY_free) + if C.EVP_PKEY_assign(params, key_type, ctx) ~= 1 then + return nil, format_error("EVP_PKEY_assign") + end + return params + end + end + + local pctx = C.EVP_PKEY_CTX_new_id(key_type, nil) + if pctx == nil then + return nil, format_error("EVP_PKEY_CTX_new_id") + end + ffi_gc(pctx, C.EVP_PKEY_CTX_free) + + if C.EVP_PKEY_paramgen_init(pctx) ~= 1 then + return nil, format_error("EVP_PKEY_paramgen_init") + end + + if key_type == evp_macro.EVP_PKEY_EC then + local curve = config.curve or 'prime192v1' + local nid = C.OBJ_ln2nid(curve) + if nid == 0 then + return nil, "unknown curve " .. curve + end + if pkey_macro.EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pctx, nid) <= 0 then + return nil, format_error("EVP_PKEY_CTX_ctrl: EC: curve_nid") + end + if not BORINGSSL then + -- use the named-curve encoding for best backward-compatibilty + -- and for playing well with go:crypto/x509 + -- # define OPENSSL_EC_NAMED_CURVE 0x001 + if pkey_macro.EVP_PKEY_CTX_set_ec_param_enc(pctx, 1) <= 0 then + return nil, format_error("EVP_PKEY_CTX_ctrl: EC: param_enc") + end + end + elseif key_type == evp_macro.EVP_PKEY_DH then + local bits = config.bits + if not config.param and not bits then + bits = 2048 + end + if bits and pkey_macro.EVP_PKEY_CTX_set_dh_paramgen_prime_len(pctx, bits) <= 0 then + return nil, format_error("EVP_PKEY_CTX_ctrl: DH: bits") + end + end + + local ctx_ptr = ffi_new("EVP_PKEY*[1]") + if C.EVP_PKEY_paramgen(pctx, ctx_ptr) ~= 1 then + return nil, format_error("EVP_PKEY_paramgen") + end + + local params = ctx_ptr[0] + ffi_gc(params, C.EVP_PKEY_free) + + return params +end + +local load_param_funcs = { + [evp_macro.EVP_PKEY_EC] = { + ["*"] = { + ["*"] = { + ['PEM_read_bio_ECPKParameters'] = load_pem_args, + -- ['d2i_ECPKParameters_bio'] = load_der_args, + } + }, + }, + [evp_macro.EVP_PKEY_DH] = { + ["*"] = { + ["*"] = { + ['PEM_read_bio_DHparams'] = load_pem_args, + -- ['d2i_DHparams_bio'] = load_der_args, + } + }, + }, +} + +local function generate_key(config) + local typ = config.type or 'RSA' + local key_type + + if typ == "RSA" then + key_type = evp_macro.EVP_PKEY_RSA + elseif typ == "EC" then + key_type = evp_macro.EVP_PKEY_EC + elseif typ == "DH" then + key_type = evp_macro.EVP_PKEY_DH + elseif evp_macro.ecx_curves[typ] then + key_type = evp_macro.ecx_curves[typ] + else + return nil, "unsupported type " .. typ + end + if key_type == 0 then + return nil, "the linked OpenSSL library doesn't support " .. typ .. " key" + end + + local pctx + + if key_type == evp_macro.EVP_PKEY_EC or key_type == evp_macro.EVP_PKEY_DH then + local params, err + if config.param then + -- HACK + config.type = nil + local ctx, err = load_pem_der(config.param, config, load_param_funcs[key_type]) + if err then + return nil, "load_pem_der: " .. err + end + if key_type == evp_macro.EVP_PKEY_EC then + local ec_group = ctx + ffi_gc(ec_group, C.EC_GROUP_free) + ctx = C.EC_KEY_new() + if ctx == nil then + return nil, "EC_KEY_new() failed" + end + if C.EC_KEY_set_group(ctx, ec_group) ~= 1 then + return nil, format_error("EC_KEY_set_group") + end + end + params = C.EVP_PKEY_new() + if not params then + return nil, format_error("EVP_PKEY_new") + end + ffi_gc(params, C.EVP_PKEY_free) + if C.EVP_PKEY_assign(params, key_type, ctx) ~= 1 then + return nil, format_error("EVP_PKEY_assign") + end + else + params, err = generate_param(key_type, config) + if err then + return nil, "generate_param: " .. err + end + end + pctx = C.EVP_PKEY_CTX_new(params, nil) + if pctx == nil then + return nil, format_error("EVP_PKEY_CTX_new") + end + else + pctx = C.EVP_PKEY_CTX_new_id(key_type, nil) + if pctx == nil then + return nil, format_error("EVP_PKEY_CTX_new_id") + end + end + + ffi_gc(pctx, C.EVP_PKEY_CTX_free) + + if C.EVP_PKEY_keygen_init(pctx) ~= 1 then + return nil, format_error("EVP_PKEY_keygen_init") + end + -- RSA key parameters are set for keygen ctx not paramgen + if key_type == evp_macro.EVP_PKEY_RSA then + local bits = config.bits or 2048 + if bits > 4294967295 then + return nil, "bits out of range" + end + + if pkey_macro.EVP_PKEY_CTX_set_rsa_keygen_bits(pctx, bits) <= 0 then + return nil, format_error("EVP_PKEY_CTX_ctrl: RSA: bits") + end + + if config.exp then + -- don't free exp as it's used internally in key + local exp = C.BN_new() + if exp == nil then + return nil, "BN_new() failed" + end + C.BN_set_word(exp, config.exp) + + if pkey_macro.EVP_PKEY_CTX_set_rsa_keygen_pubexp(pctx, exp) <= 0 then + return nil, format_error("EVP_PKEY_CTX_ctrl: RSA: exp") + end + end + end + local ctx_ptr = ffi_new("EVP_PKEY*[1]") + -- TODO: move to use EVP_PKEY_gen after drop support for <1.1.1 + if C.EVP_PKEY_keygen(pctx, ctx_ptr) ~= 1 then + return nil, format_error("EVP_PKEY_gen") + end + return ctx_ptr[0] +end + +local load_key_try_funcs = {} do + -- TODO: pkcs1 load functions are not required in openssl 3.0 + local _load_key_try_funcs = { + PEM = { + -- Note: make sure we always try load priv key first + pr = { + ['PEM_read_bio_PrivateKey'] = load_pem_args, + -- disable in openssl3.0, PEM_read_bio_PrivateKey can read pkcs1 in 3.0 + ['PEM_read_bio_RSAPrivateKey'] = not OPENSSL_3X and load_pem_args or nil, + }, + pu = { + ['PEM_read_bio_PUBKEY'] = load_pem_args, + -- disable in openssl3.0, PEM_read_bio_PrivateKey can read pkcs1 in 3.0 + ['PEM_read_bio_RSAPublicKey'] = not OPENSSL_3X and load_pem_args or nil, + }, + }, + DER = { + pr = { ['d2i_PrivateKey_bio'] = load_der_args, }, + pu = { ['d2i_PUBKEY_bio'] = load_der_args, }, + }, + JWK = { + pr = { ['load_jwk'] = {}, }, + } + } + -- populate * funcs + local all_funcs = {} + local typ_funcs = {} + for fmt, ffs in pairs(_load_key_try_funcs) do + load_key_try_funcs[fmt] = ffs + + local funcs = {} + for typ, fs in pairs(ffs) do + for f, arg in pairs(fs) do + funcs[f] = arg + all_funcs[f] = arg + if not typ_funcs[typ] then + typ_funcs[typ] = {} + end + typ_funcs[typ] = arg + end + end + load_key_try_funcs[fmt]["*"] = funcs + end + load_key_try_funcs["*"] = {} + load_key_try_funcs["*"]["*"] = all_funcs + for typ, fs in pairs(typ_funcs) do + load_key_try_funcs[typ] = fs + end +end + +local function __tostring(self, is_priv, fmt, is_pkcs1) + if fmt == "JWK" then + return jwk_lib.dump_jwk(self, is_priv) + elseif is_pkcs1 then + if fmt ~= "PEM" or self.key_type ~= evp_macro.EVP_PKEY_RSA then + return nil, "PKCS#1 format is only supported to encode RSA key in \"PEM\" format" + elseif OPENSSL_3X then -- maybe possible with OSSL_ENCODER_CTX_new_for_pkey though + return nil, "writing out RSA key in PKCS#1 format is not supported in OpenSSL 3.0" + end + end + if is_priv then + if fmt == "DER" then + return bio_util.read_wrap(C.i2d_PrivateKey_bio, self.ctx) + end + -- PEM + if is_pkcs1 then + local rsa = get_pkey_key[evp_macro.EVP_PKEY_RSA](self.ctx) + if rsa == nil then + return nil, "unable to read RSA key for writing" + end + return bio_util.read_wrap(C.PEM_write_bio_RSAPrivateKey, + rsa, + nil, nil, 0, nil, nil) + end + return bio_util.read_wrap(C.PEM_write_bio_PrivateKey, + self.ctx, + nil, nil, 0, nil, nil) + else + if fmt == "DER" then + return bio_util.read_wrap(C.i2d_PUBKEY_bio, self.ctx) + end + -- PEM + if is_pkcs1 then + local rsa = get_pkey_key[evp_macro.EVP_PKEY_RSA](self.ctx) + if rsa == nil then + return nil, "unable to read RSA key for writing" + end + return bio_util.read_wrap(C.PEM_write_bio_RSAPublicKey, rsa) + end + return bio_util.read_wrap(C.PEM_write_bio_PUBKEY, self.ctx) + end + +end + +local _M = {} +local mt = { __index = _M, __tostring = __tostring } + +local empty_table = {} +local evp_pkey_ptr_ct = ffi.typeof('EVP_PKEY*') + +function _M.new(s, opts) + local ctx, err + s = s or {} + if type(s) == 'table' then + ctx, err = generate_key(s) + if err then + err = "pkey.new:generate_key: " .. err + end + elseif type(s) == 'string' then + ctx, err = load_pem_der(s, opts or empty_table, load_key_try_funcs) + if err then + err = "pkey.new:load_key: " .. err + end + elseif type(s) == 'cdata' then + if ffi.istype(evp_pkey_ptr_ct, s) then + ctx = s + else + return nil, "pkey.new: expect a EVP_PKEY* cdata at #1" + end + else + return nil, "pkey.new: unexpected type " .. type(s) .. " at #1" + end + + if err then + return nil, err + end + + ffi_gc(ctx, C.EVP_PKEY_free) + + local key_type = OPENSSL_3X and C.EVP_PKEY_get_base_id(ctx) or C.EVP_PKEY_base_id(ctx) + if key_type == 0 then + return nil, "pkey.new: cannot get key_type" + end + local key_type_is_ecx = (key_type == evp_macro.EVP_PKEY_ED25519) or + (key_type == evp_macro.EVP_PKEY_X25519) or + (key_type == evp_macro.EVP_PKEY_ED448) or + (key_type == evp_macro.EVP_PKEY_X448) + + -- although OpenSSL discourages to use this size for digest/verify + -- but this is good enough for now + local buf_size = OPENSSL_3X and C.EVP_PKEY_get_size(ctx) or C.EVP_PKEY_size(ctx) + + local self = setmetatable({ + ctx = ctx, + pkey_ctx = nil, + rsa_padding = nil, + key_type = key_type, + key_type_is_ecx = key_type_is_ecx, + buf = ctypes.uchar_array(buf_size), + buf_size = buf_size, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(evp_pkey_ptr_ct, l.ctx) +end + +function _M:get_key_type() + return objects_lib.nid2table(self.key_type) +end + +function _M:get_default_digest_type() + if BORINGSSL then + return nil, "BoringSSL doesn't have default digest for pkey" + end + + local nid = ptr_of_int() + local code = C.EVP_PKEY_get_default_digest_nid(self.ctx, nid) + if code == -2 then + return nil, "operation is not supported by the public key algorithm" + elseif code <= 0 then + return nil, format_error("get_default_digest", code) + end + + local ret = objects_lib.nid2table(nid[0]) + ret.mandatory = code == 2 + return ret +end + +function _M:get_provider_name() + if not OPENSSL_3X then + return false, "pkey:get_provider_name is not supported" + end + local p = C.EVP_PKEY_get0_provider(self.ctx) + if p == nil then + return nil + end + return ffi_str(C.OSSL_PROVIDER_get0_name(p)) +end + +if OPENSSL_3X then + local param_lib = require "resty.openssl.param" + _M.settable_params, _M.set_params, _M.gettable_params, _M.get_param = param_lib.get_params_func("EVP_PKEY", "key_type") +end + +function _M:get_parameters() + if not self.key_type_is_ecx then + local getter = get_pkey_key[self.key_type] + if not getter then + return nil, "key getter not defined" + end + local key = getter(self.ctx) + if key == nil then + return nil, format_error("EVP_PKEY_get0_{key}") + end + + if self.key_type == evp_macro.EVP_PKEY_RSA then + return rsa_lib.get_parameters(key) + elseif self.key_type == evp_macro.EVP_PKEY_EC then + return ec_lib.get_parameters(key) + elseif self.key_type == evp_macro.EVP_PKEY_DH then + return dh_lib.get_parameters(key) + end + else + return ecx_lib.get_parameters(self.ctx) + end +end + +function _M:set_parameters(opts) + if not self.key_type_is_ecx then + local getter = get_pkey_key[self.key_type] + if not getter then + return nil, "key getter not defined" + end + local key = getter(self.ctx) + if key == nil then + return nil, format_error("EVP_PKEY_get0_{key}") + end + + if self.key_type == evp_macro.EVP_PKEY_RSA then + return rsa_lib.set_parameters(key, opts) + elseif self.key_type == evp_macro.EVP_PKEY_EC then + return ec_lib.set_parameters(key, opts) + elseif self.key_type == evp_macro.EVP_PKEY_DH then + return dh_lib.set_parameters(key, opts) + end + else + -- for ecx keys we always create a new EVP_PKEY and release the old one + local ctx, err = ecx_lib.set_parameters(self.key_type, self.ctx, opts) + if err then + return false, err + end + self.ctx = ctx + end +end + +function _M:is_private() + local params = self:get_parameters() + if self.key_type == evp_macro.EVP_PKEY_RSA then + return params.d ~= nil + else + return params.private ~= nil + end +end + +local ASYMMETRIC_OP_ENCRYPT = 0x1 +local ASYMMETRIC_OP_DECRYPT = 0x2 +local ASYMMETRIC_OP_SIGN_RAW = 0x4 +local ASYMMETRIC_OP_VERIFY_RECOVER = 0x8 + +local function asymmetric_routine(self, s, op, padding) + local pkey_ctx + + if self.key_type == evp_macro.EVP_PKEY_RSA then + if padding then + padding = tonumber(padding) + if not padding then + return nil, "invalid padding: " .. __tostring(padding) + end + else + padding = rsa_macro.paddings.RSA_PKCS1_PADDING + end + end + + if self.pkey_ctx ~= nil and + (self.key_type ~= evp_macro.EVP_PKEY_RSA or self.rsa_padding == padding) then + pkey_ctx = self.pkey_ctx + else + pkey_ctx = C.EVP_PKEY_CTX_new(self.ctx, nil) + if pkey_ctx == nil then + return nil, format_error("pkey:asymmetric_routine EVP_PKEY_CTX_new()") + end + ffi_gc(pkey_ctx, C.EVP_PKEY_CTX_free) + self.pkey_ctx = pkey_ctx + end + + local f, fint, op_name + if op == ASYMMETRIC_OP_ENCRYPT then + fint = C.EVP_PKEY_encrypt_init + f = C.EVP_PKEY_encrypt + op_name = "encrypt" + elseif op == ASYMMETRIC_OP_DECRYPT then + fint = C.EVP_PKEY_decrypt_init + f = C.EVP_PKEY_decrypt + op_name = "decrypt" + elseif op == ASYMMETRIC_OP_SIGN_RAW then + fint = C.EVP_PKEY_sign_init + f = C.EVP_PKEY_sign + op_name = "sign" + elseif op == ASYMMETRIC_OP_VERIFY_RECOVER then + fint = C.EVP_PKEY_verify_recover_init + f = C.EVP_PKEY_verify_recover + op_name = "verify_recover" + else + error("bad \"op\", got " .. op, 2) + end + + local code = fint(pkey_ctx) + if code < 1 then + return nil, format_error("pkey:asymmetric_routine EVP_PKEY_" .. op_name .. "_init", code) + end + + -- EVP_PKEY_CTX_ctrl must be called after *_init + if self.key_type == evp_macro.EVP_PKEY_RSA and padding then + if pkey_macro.EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, padding) ~= 1 then + return nil, format_error("pkey:asymmetric_routine EVP_PKEY_CTX_set_rsa_padding") + end + self.rsa_padding = padding + end + + local length = ptr_of_size_t(self.buf_size) + + if f(pkey_ctx, self.buf, length, s, #s) <= 0 then + return nil, format_error("pkey:asymmetric_routine EVP_PKEY_" .. op_name) + end + + return ffi_str(self.buf, length[0]), nil +end + +_M.PADDINGS = rsa_macro.paddings + +function _M:encrypt(s, padding) + return asymmetric_routine(self, s, ASYMMETRIC_OP_ENCRYPT, padding) +end + +function _M:decrypt(s, padding) + return asymmetric_routine(self, s, ASYMMETRIC_OP_DECRYPT, padding) +end + +function _M:sign_raw(s, padding) + -- TODO: temporary hack before OpenSSL has proper check for existence of private key + if self.key_type_is_ecx and not self:is_private() then + return nil, "pkey:sign_raw: missing private key" + end + + return asymmetric_routine(self, s, ASYMMETRIC_OP_SIGN_RAW, padding) +end + +function _M:verify_recover(s, padding) + return asymmetric_routine(self, s, ASYMMETRIC_OP_VERIFY_RECOVER, padding) +end + +local evp_pkey_ctx_ptr_ptr_ct = ffi.typeof('EVP_PKEY_CTX*[1]') + +local function sign_verify_prepare(self, fint, md_alg, padding, opts) + local pkey_ctx + + if self.key_type == evp_macro.EVP_PKEY_RSA and padding then + pkey_ctx = C.EVP_PKEY_CTX_new(self.ctx, nil) + if pkey_ctx == nil then + return nil, format_error("pkey:sign_verify_prepare EVP_PKEY_CTX_new()") + end + ffi_gc(pkey_ctx, C.EVP_PKEY_CTX_free) + end + + local md_ctx = C.EVP_MD_CTX_new() + if md_ctx == nil then + return nil, "pkey:sign_verify_prepare: EVP_MD_CTX_new() failed" + end + ffi_gc(md_ctx, C.EVP_MD_CTX_free) + + local algo + if md_alg then + if OPENSSL_3X then + algo = C.EVP_MD_fetch(ctx_lib.get_libctx(), md_alg, nil) + else + algo = C.EVP_get_digestbyname(md_alg) + end + if algo == nil then + return nil, string.format("pkey:sign_verify_prepare: invalid digest type \"%s\"", md_alg) + end + end + + local ppkey_ctx = evp_pkey_ctx_ptr_ptr_ct() + ppkey_ctx[0] = pkey_ctx + if fint(md_ctx, ppkey_ctx, algo, nil, self.ctx) ~= 1 then + return nil, format_error("pkey:sign_verify_prepare: Init failed") + end + + if self.key_type == evp_macro.EVP_PKEY_RSA then + if padding then + if pkey_macro.EVP_PKEY_CTX_set_rsa_padding(ppkey_ctx[0], padding) ~= 1 then + return nil, format_error("pkey:sign_verify_prepare EVP_PKEY_CTX_set_rsa_padding") + end + end + if opts and opts.pss_saltlen and padding ~= rsa_macro.paddings.RSA_PKCS1_PSS_PADDING then + if pkey_macro.EVP_PKEY_CTX_set_rsa_pss_saltlen(ppkey_ctx[0], opts.pss_saltlen) ~= 1 then + return nil, format_error("pkey:sign_verify_prepare EVP_PKEY_CTX_set_rsa_pss_saltlen") + end + end + end + + return md_ctx +end + +function _M:sign(digest, md_alg, padding, opts) + -- TODO: temporary hack before OpenSSL has proper check for existence of private key + if self.key_type_is_ecx and not self:is_private() then + return nil, "pkey:sign: missing private key" + end + + if digest_lib.istype(digest) then + local length = ptr_of_uint() + if C.EVP_SignFinal(digest.ctx, self.buf, length, self.ctx) ~= 1 then + return nil, format_error("pkey:sign: EVP_SignFinal") + end + return ffi_str(self.buf, length[0]), nil + elseif type(digest) == "string" then + if not OPENSSL_111_OR_LATER and not BORINGSSL then + -- we can still support earilier version with *Update and *Final + -- but we choose to not relying on the legacy interface for simplicity + return nil, "pkey:sign: new-style sign only available in OpenSSL 1.1.1 (or BoringSSL 1.1.0) or later" + elseif BORINGSSL and not md_alg and not self.key_type_is_ecx then + return nil, "pkey:sign: BoringSSL doesn't provide default digest, md_alg must be specified" + end + + local md_ctx, err = sign_verify_prepare(self, C.EVP_DigestSignInit, md_alg, padding, opts) + if err then + return nil, err + end + + local length = ptr_of_size_t(self.buf_size) + if C.EVP_DigestSign(md_ctx, self.buf, length, digest, #digest) ~= 1 then + return nil, format_error("pkey:sign: EVP_DigestSign") + end + return ffi_str(self.buf, length[0]), nil + else + return nil, "pkey:sign: expect a digest instance or a string at #1" + end +end + +function _M:verify(signature, digest, md_alg, padding, opts) + if type(signature) ~= "string" then + return nil, "pkey:verify: expect a string at #1" + end + + local code + if digest_lib.istype(digest) then + code = C.EVP_VerifyFinal(digest.ctx, signature, #signature, self.ctx) + elseif type(digest) == "string" then + if not OPENSSL_111_OR_LATER and not BORINGSSL then + -- we can still support earilier version with *Update and *Final + -- but we choose to not relying on the legacy interface for simplicity + return nil, "pkey:verify: new-style verify only available in OpenSSL 1.1.1 (or BoringSSL 1.1.0) or later" + elseif BORINGSSL and not md_alg and not self.key_type_is_ecx then + return nil, "pkey:verify: BoringSSL doesn't provide default digest, md_alg must be specified" + end + + local md_ctx, err = sign_verify_prepare(self, C.EVP_DigestVerifyInit, md_alg, padding, opts) + if err then + return nil, err + end + + code = C.EVP_DigestVerify(md_ctx, signature, #signature, digest, #digest) + else + return nil, "pkey:verify: expect a digest instance or a string at #2" + end + + if code == 0 then + return false, nil + elseif code == 1 then + return true, nil + end + return false, format_error("pkey:verify") +end + +function _M:derive(peerkey) + if not self.istype(peerkey) then + return nil, "pkey:derive: expect a pkey instance at #1" + end + local pctx = C.EVP_PKEY_CTX_new(self.ctx, nil) + if pctx == nil then + return nil, "pkey:derive: EVP_PKEY_CTX_new() failed" + end + ffi_gc(pctx, C.EVP_PKEY_CTX_free) + local code = C.EVP_PKEY_derive_init(pctx) + if code <= 0 then + return nil, format_error("pkey:derive: EVP_PKEY_derive_init", code) + end + + code = C.EVP_PKEY_derive_set_peer(pctx, peerkey.ctx) + if code <= 0 then + return nil, format_error("pkey:derive: EVP_PKEY_derive_set_peer", code) + end + + local buflen = ptr_of_size_t() + code = C.EVP_PKEY_derive(pctx, nil, buflen) + if code <= 0 then + return nil, format_error("pkey:derive: EVP_PKEY_derive check buffer size", code) + end + + local buf = ctypes.uchar_array(buflen[0]) + code = C.EVP_PKEY_derive(pctx, buf, buflen) + if code <= 0 then + return nil, format_error("pkey:derive: EVP_PKEY_derive", code) + end + + return ffi_str(buf, buflen[0]) +end + +local function pub_or_priv_is_pri(pub_or_priv) + if pub_or_priv == 'private' or pub_or_priv == 'PrivateKey' then + return true + elseif not pub_or_priv or pub_or_priv == 'public' or pub_or_priv == 'PublicKey' then + return false + else + return nil, string.format("can only export private or public key, not %s", pub_or_priv) + end +end + +function _M:tostring(pub_or_priv, fmt, pkcs1) + local is_priv, err = pub_or_priv_is_pri(pub_or_priv) + if err then + return nil, "pkey:tostring: " .. err + end + return __tostring(self, is_priv, fmt, pkcs1) +end + +function _M:to_PEM(pub_or_priv, pkcs1) + return self:tostring(pub_or_priv, "PEM", pkcs1) +end + +function _M.paramgen(config) + local typ = config.type + local key_type, write_func, get_ctx_func + if typ == "EC" then + key_type = evp_macro.EVP_PKEY_EC + if key_type == 0 then + return nil, "pkey.paramgen: the linked OpenSSL library doesn't support EC key" + end + write_func = C.PEM_write_bio_ECPKParameters + get_ctx_func = function(ctx) + local ctx = get_pkey_key[key_type](ctx) + if ctx == nil then + error(format_error("pkey.paramgen: EVP_PKEY_get0_{key}")) + end + return C.EC_KEY_get0_group(ctx) + end + elseif typ == "DH" then + key_type = evp_macro.EVP_PKEY_DH + if key_type == 0 then + return nil, "pkey.paramgen: the linked OpenSSL library doesn't support DH key" + end + write_func = C.PEM_write_bio_DHparams + get_ctx_func = get_pkey_key[key_type] + else + return nil, "pkey.paramgen: unsupported type " .. type + end + + local params, err = generate_param(key_type, config) + if err then + return nil, "pkey.paramgen: generate_param: " .. err + end + + local ctx = get_ctx_func(params) + if ctx == nil then + return nil, format_error("pkey.paramgen: EVP_PKEY_get0_{key}") + end + + return bio_util.read_wrap(write_func, ctx) +end + +return _M diff --git a/server/resty/openssl/provider.lua b/server/resty/openssl/provider.lua new file mode 100644 index 0000000..2879ac3 --- /dev/null +++ b/server/resty/openssl/provider.lua @@ -0,0 +1,136 @@ +local ffi = require "ffi" +local C = ffi.C + +require "resty.openssl.include.provider" +local param_lib = require "resty.openssl.param" +local ctx_lib = require "resty.openssl.ctx" +local null = require("resty.openssl.auxiliary.ctypes").null +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local format_error = require("resty.openssl.err").format_error + +if not OPENSSL_3X then + error("provider is only supported since OpenSSL 3.0") +end + +local _M = {} +local mt = {__index = _M} + +local ossl_provider_ctx_ct = ffi.typeof('OSSL_PROVIDER*') + +function _M.load(name, try) + local ctx + local libctx = ctx_lib.get_libctx() + if try then + ctx = C.OSSL_PROVIDER_try_load(libctx, name) + if ctx == nil then + return nil, format_error("provider.try_load") + end + else + ctx = C.OSSL_PROVIDER_load(libctx, name) + if ctx == nil then + return nil, format_error("provider.load") + end + end + + return setmetatable({ + ctx = ctx, + param_types = nil, + }, mt), nil +end + +function _M.set_default_search_path(path) + C.OSSL_PROVIDER_set_default_search_path(ctx_lib.get_libctx(), path) +end + +function _M.is_available(name) + return C.OSSL_PROVIDER_available(ctx_lib.get_libctx(), name) == 1 +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(ossl_provider_ctx_ct, l.ctx) +end + +function _M:unload() + if C.OSSL_PROVIDER_unload(self.ctx) == nil then + return false, format_error("provider:unload") + end + return true +end + +function _M:self_test() + if C.OSSL_PROVIDER_self_test(self.ctx) == nil then + return false, format_error("provider:self_test") + end + return true +end + +local params_well_known = { + -- Well known parameter names that core passes to providers + ["openssl-version"] = param_lib.OSSL_PARAM_UTF8_PTR, + ["provider-name"] = param_lib.OSSL_PARAM_UTF8_PTR, + ["module-filename"] = param_lib.OSSL_PARAM_UTF8_PTR, + + -- Well known parameter names that Providers can define + ["name"] = param_lib.OSSL_PARAM_UTF8_PTR, + ["version"] = param_lib.OSSL_PARAM_UTF8_PTR, + ["buildinfo"] = param_lib.OSSL_PARAM_UTF8_PTR, + ["status"] = param_lib.OSSL_PARAM_INTEGER, + ["security-checks"] = param_lib.OSSL_PARAM_INTEGER, +} + +local function load_gettable_names(ctx) + local schema = {} + for k, v in pairs(params_well_known) do + schema[k] = v + end + + local err + schema, err = param_lib.parse_params_schema( + C.OSSL_PROVIDER_gettable_params(ctx), schema) + if err then + return nil, err + end + + return schema +end + +function _M:get_params(...) + local keys = {...} + local key_length = #keys + if key_length == 0 then + return nil, "provider:get_params: at least one key is required" + end + + if not self.param_types then + local param_types, err = load_gettable_names(self.ctx) + if err then + return nil, "provider:get_params: " .. err + end + self.param_types = param_types + end + + local buffers = {} + for _, key in ipairs(keys) do + buffers[key] = null + end + local req, err = param_lib.construct(buffers, key_length, self.param_types) + if not req then + return nil, "provider:get_params: failed to construct params: " .. err + end + + if C.OSSL_PROVIDER_get_params(self.ctx, req) ~= 1 then + return nil, format_error("provider:get_params") + end + + buffers, err = param_lib.parse(buffers, key_length, self.param_types) + if err then + return nil, "provider:get_params: failed to parse params: " .. err + end + + if key_length == 1 then + return buffers[keys[1]] + end + return buffers +end + +return _M diff --git a/server/resty/openssl/rand.lua b/server/resty/openssl/rand.lua new file mode 100644 index 0000000..be54da9 --- /dev/null +++ b/server/resty/openssl/rand.lua @@ -0,0 +1,51 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string + +require "resty.openssl.include.rand" +local ctx_lib = require "resty.openssl.ctx" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local buf +local buf_size = 0 +local function bytes(length, private, strength) + if type(length) ~= "number" then + return nil, "rand.bytes: expect a number at #1" + elseif strength and type(strength) ~= "number" then + return nil, "rand.bytes: expect a number at #3" + end + -- generally we don't need manually reseed rng + -- https://www.openssl.org/docs/man1.1.1/man3/RAND_seed.html + + -- initialize or resize buffer + if not buf or buf_size < length then + buf = ctypes.uchar_array(length) + buf_size = length + end + + local code + if OPENSSL_3X then + if private then + code = C.RAND_priv_bytes_ex(ctx_lib.get_libctx(), buf, length, strength or 0) + else + code = C.RAND_bytes_ex(ctx_lib.get_libctx(), buf, length, strength or 0) + end + else + if private then + code = C.RAND_priv_bytes(buf, length) + else + code = C.RAND_bytes(buf, length) + end + end + if code ~= 1 then + return nil, format_error("rand.bytes", code) + end + + return ffi_str(buf, length) +end + +return { + bytes = bytes, +} diff --git a/server/resty/openssl/rsa.lua b/server/resty/openssl/rsa.lua new file mode 100644 index 0000000..f3af394 --- /dev/null +++ b/server/resty/openssl/rsa.lua @@ -0,0 +1,155 @@ +local ffi = require "ffi" +local C = ffi.C + +local bn_lib = require "resty.openssl.bn" + +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local OPENSSL_11_OR_LATER = require("resty.openssl.version").OPENSSL_11_OR_LATER +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +_M.params = {"n", "e", "d", "p", "q", "dmp1", "dmq1", "iqmp"} + +local empty_table = {} +local bn_ptrptr_ct = ffi.typeof("const BIGNUM *[1]") +function _M.get_parameters(rsa_st) + -- {"n", "e", "d", "p", "q", "dmp1", "dmq1", "iqmp"} + return setmetatable(empty_table, { + __index = function(_, k) + local ptr, ret + if OPENSSL_11_OR_LATER then + ptr = bn_ptrptr_ct() + end + + if k == 'n' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_key(rsa_st, ptr, nil, nil) + end + elseif k == 'e' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_key(rsa_st, nil, ptr, nil) + end + elseif k == 'd' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_key(rsa_st, nil, nil, ptr) + end + elseif k == 'p' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_factors(rsa_st, ptr, nil) + end + elseif k == 'q' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_factors(rsa_st, nil, ptr) + end + elseif k == 'dmp1' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_crt_params(rsa_st, ptr, nil, nil) + end + elseif k == 'dmq1' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_crt_params(rsa_st, nil, ptr, nil) + end + elseif k == 'iqmp' then + if OPENSSL_11_OR_LATER then + C.RSA_get0_crt_params(rsa_st, nil, nil, ptr) + end + else + return nil, "rsa.get_parameters: unknown parameter \"" .. k .. "\" for RSA key" + end + + if OPENSSL_11_OR_LATER then + ret = ptr[0] + elseif OPENSSL_10 then + ret = rsa_st[k] + end + + if ret == nil then + return nil + end + return bn_lib.dup(ret) + end + }), nil +end + +local function dup_bn_value(v) + if not bn_lib.istype(v) then + return nil, "expect value to be a bn instance" + end + local bn = C.BN_dup(v.ctx) + if bn == nil then + return nil, "BN_dup() failed" + end + return bn +end + +function _M.set_parameters(rsa_st, opts) + local err + local opts_bn = {} + -- remember which parts of BNs has been added to rsa_st, they should be freed + -- by RSA_free and we don't cleanup them on failure + local cleanup_from_idx = 1 + -- dup input + local do_set_key, do_set_factors, do_set_crt_params + for k, v in pairs(opts) do + opts_bn[k], err = dup_bn_value(v) + if err then + err = "rsa.set_parameters: cannot process parameter \"" .. k .. "\":" .. err + goto cleanup_with_error + end + if k == "n" or k == "e" or k == "d" then + do_set_key = true + elseif k == "p" or k == "q" then + do_set_factors = true + elseif k == "dmp1" or k == "dmq1" or k == "iqmp" then + do_set_crt_params = true + end + end + if OPENSSL_11_OR_LATER then + -- "The values n and e must be non-NULL the first time this function is called on a given RSA object." + -- thus we force to set them together + local code + if do_set_key then + code = C.RSA_set0_key(rsa_st, opts_bn["n"], opts_bn["e"], opts_bn["d"]) + if code == 0 then + err = format_error("rsa.set_parameters: RSA_set0_key") + goto cleanup_with_error + end + end + cleanup_from_idx = cleanup_from_idx + 3 + if do_set_factors then + code = C.RSA_set0_factors(rsa_st, opts_bn["p"], opts_bn["q"]) + if code == 0 then + err = format_error("rsa.set_parameters: RSA_set0_factors") + goto cleanup_with_error + end + end + cleanup_from_idx = cleanup_from_idx + 2 + if do_set_crt_params then + code = C.RSA_set0_crt_params(rsa_st, opts_bn["dmp1"], opts_bn["dmq1"], opts_bn["iqmp"]) + if code == 0 then + err = format_error("rsa.set_parameters: RSA_set0_crt_params") + goto cleanup_with_error + end + end + return true + elseif OPENSSL_10 then + for k, v in pairs(opts_bn) do + if rsa_st[k] ~= nil then + C.BN_free(rsa_st[k]) + end + rsa_st[k]= v + end + return true + end + +::cleanup_with_error:: + for i, k in pairs(_M.params) do + if i >= cleanup_from_idx then + C.BN_free(opts_bn[k]) + end + end + return false, err +end + +return _M diff --git a/server/resty/openssl/ssl.lua b/server/resty/openssl/ssl.lua new file mode 100644 index 0000000..d3eee90 --- /dev/null +++ b/server/resty/openssl/ssl.lua @@ -0,0 +1,353 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string +local ffi_cast = ffi.cast + +require "resty.openssl.include.ssl" + +local nginx_aux = require("resty.openssl.auxiliary.nginx") +local x509_lib = require("resty.openssl.x509") +local chain_lib = require("resty.openssl.x509.chain") +local stack_lib = require("resty.openssl.stack") +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local OPENSSL_10 = require("resty.openssl.version").OPENSSL_10 +local format_error = require("resty.openssl.err").format_error + +local _M = { + SSL_VERIFY_NONE = 0x00, + SSL_VERIFY_PEER = 0x01, + SSL_VERIFY_FAIL_IF_NO_PEER_CERT = 0x02, + SSL_VERIFY_CLIENT_ONCE = 0x04, + SSL_VERIFY_POST_HANDSHAKE = 0x08, +} + +local ops = { + SSL_OP_NO_EXTENDED_MASTER_SECRET = 0x00000001, + SSL_OP_CLEANSE_PLAINTEXT = 0x00000002, + SSL_OP_LEGACY_SERVER_CONNECT = 0x00000004, + SSL_OP_TLSEXT_PADDING = 0x00000010, + SSL_OP_SAFARI_ECDHE_ECDSA_BUG = 0x00000040, + SSL_OP_IGNORE_UNEXPECTED_EOF = 0x00000080, + SSL_OP_DISABLE_TLSEXT_CA_NAMES = 0x00000200, + SSL_OP_ALLOW_NO_DHE_KEX = 0x00000400, + SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS = 0x00000800, + SSL_OP_NO_QUERY_MTU = 0x00001000, + SSL_OP_COOKIE_EXCHANGE = 0x00002000, + SSL_OP_NO_TICKET = 0x00004000, + SSL_OP_CISCO_ANYCONNECT = 0x00008000, + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION = 0x00010000, + SSL_OP_NO_COMPRESSION = 0x00020000, + SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION = 0x00040000, + SSL_OP_NO_ENCRYPT_THEN_MAC = 0x00080000, + SSL_OP_ENABLE_MIDDLEBOX_COMPAT = 0x00100000, + SSL_OP_PRIORITIZE_CHACHA = 0x00200000, + SSL_OP_CIPHER_SERVER_PREFERENCE = 0x00400000, + SSL_OP_TLS_ROLLBACK_BUG = 0x00800000, + SSL_OP_NO_ANTI_REPLAY = 0x01000000, + SSL_OP_NO_SSLv3 = 0x02000000, + SSL_OP_NO_TLSv1 = 0x04000000, + SSL_OP_NO_TLSv1_2 = 0x08000000, + SSL_OP_NO_TLSv1_1 = 0x10000000, + SSL_OP_NO_TLSv1_3 = 0x20000000, + SSL_OP_NO_DTLSv1 = 0x04000000, + SSL_OP_NO_DTLSv1_2 = 0x08000000, + SSL_OP_NO_RENEGOTIATION = 0x40000000, + SSL_OP_CRYPTOPRO_TLSEXT_BUG = 0x80000000, +} +ops.SSL_OP_NO_SSL_MASK = ops.SSL_OP_NO_SSLv3 + ops.SSL_OP_NO_TLSv1 + ops.SSL_OP_NO_TLSv1_1 + + ops.SSL_OP_NO_TLSv1_2 + ops.SSL_OP_NO_TLSv1_3 +ops.SSL_OP_NO_DTLS_MASK = ops.SSL_OP_NO_DTLSv1 + ops.SSL_OP_NO_DTLSv1_2 +for k, v in pairs(ops) do + _M[k] = v +end + +local mt = {__index = _M} + +local ssl_ptr_ct = ffi.typeof('SSL*') + +local stack_of_ssl_cipher_iter = function(ctx) + return stack_lib.mt_of("SSL_CIPHER", function(x) return x end, {}, true).__ipairs({ctx = ctx}) +end + +function _M.from_request() + -- don't GC this + local ctx, err = nginx_aux.get_req_ssl() + if err ~= nil then + return nil, err + end + + return setmetatable({ + ctx = ctx, + -- the cdata is not manage by Lua, don't GC on Lua side + _managed = false, + -- this is the client SSL session + _server = true, + }, mt) +end + +function _M.from_socket(socket) + if not socket then + return nil, "expect a ngx.socket.tcp instance at #1" + end + -- don't GC this + local ctx, err = nginx_aux.get_socket_ssl(socket) + if err ~= nil then + return nil, err + end + + return setmetatable({ + ctx = ctx, + -- the cdata is not manage by Lua, don't GC on Lua side + _managed = false, + -- this is the client SSL session + _server = false, + }, mt) +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(ssl_ptr_ct, l.ctx) +end + +function _M:get_peer_certificate() + local x509 + if OPENSSL_3X then + x509 = C.SSL_get1_peer_certificate(self.ctx) + else + x509 = C.SSL_get_peer_certificate(self.ctx) + end + + if x509 == nil then + return nil + end + ffi.gc(x509, C.X509_free) + + local err + -- always copy, although the ref counter of returned x509 is + -- already increased by one. + x509, err = x509_lib.dup(x509) + if err then + return nil, err + end + + return x509 +end + +function _M:get_peer_cert_chain() + local stack = C.SSL_get_peer_cert_chain(self.ctx) + + if stack == nil then + return nil + end + + return chain_lib.dup(stack) +end + +-- TLSv1.3 +function _M:set_ciphersuites(ciphers) + if C.SSL_set_ciphersuites(self.ctx, ciphers) ~= 1 then + return false, format_error("ssl:set_ciphers: SSL_set_ciphersuites") + end + + return true +end + +-- TLSv1.2 and lower +function _M:set_cipher_list(ciphers) + if C.SSL_set_cipher_list(self.ctx, ciphers) ~= 1 then + return false, format_error("ssl:set_ciphers: SSL_set_cipher_list") + end + + return true +end + +function _M:get_ciphers() + local ciphers = C.SSL_get_ciphers(self.ctx) + + if ciphers == nil then + return nil + end + + local ret = {} + + for i, cipher in stack_of_ssl_cipher_iter(ciphers) do + cipher = C.SSL_CIPHER_get_name(cipher) + if cipher == nil then + return nil, format_error("ssl:get_ciphers: SSL_CIPHER_get_name") + end + ret[i] = ffi_str(cipher) + end + + return table.concat(ret, ":") +end + +function _M:get_cipher_name() + local cipher = C.SSL_get_current_cipher(self.ctx) + + if cipher == nil then + return nil + end + + cipher = C.SSL_CIPHER_get_name(cipher) + if cipher == nil then + return nil, format_error("ssl:get_cipher_name: SSL_CIPHER_get_name") + end + return ffi_str(cipher) +end + +function _M:set_timeout(tm) + local session = C.SSL_get_session(self.ctx) + + if session == nil then + return false, format_error("ssl:set_timeout: SSL_get_session") + end + + if C.SSL_SESSION_set_timeout(session, tm) ~= 1 then + return false, format_error("ssl:set_timeout: SSL_SESSION_set_timeout") + end + return true +end + +function _M:get_timeout() + local session = C.SSL_get_session(self.ctx) + + if session == nil then + return false, format_error("ssl:get_timeout: SSL_get_session") + end + + return tonumber(C.SSL_SESSION_get_timeout(session)) +end + +local ssl_verify_default_cb = ffi_cast("verify_callback", function() + return 1 +end) + +function _M:set_verify(mode, cb) + if self._verify_cb then + self._verify_cb:free() + end + + if cb then + cb = ffi_cast("verify_callback", cb) + self._verify_cb = cb + end + + C.SSL_set_verify(self.ctx, mode, cb or ssl_verify_default_cb) + + return true +end + +function _M:free_verify_cb() + if self._verify_cb then + self._verify_cb:free() + self._verify_cb = nil + end +end + +function _M:add_client_ca(x509) + if not self._server then + return false, "ssl:add_client_ca is only supported on server side" + end + + if not x509_lib.istype(x509) then + return false, "expect a x509 instance at #1" + end + + if C.SSL_add_client_CA(self.ctx, x509.ctx) ~= 1 then + return false, format_error("ssl:add_client_ca: SSL_add_client_CA") + end + + return true +end + +function _M:set_options(...) + local bitmask = 0 + for _, opt in ipairs({...}) do + bitmask = bit.bor(bitmask, opt) + end + + if OPENSSL_10 then + bitmask = C.SSL_ctrl(self.ctx, 32, bitmask, nil) -- SSL_CTRL_OPTIONS + else + bitmask = C.SSL_set_options(self.ctx, bitmask) + end + + return tonumber(bitmask) +end + +function _M:get_options(readable) + local bitmask + if OPENSSL_10 then + bitmask = C.SSL_ctrl(self.ctx, 32, 0, nil) -- SSL_CTRL_OPTIONS + else + bitmask = C.SSL_get_options(self.ctx) + end + + if not readable then + return tonumber(bitmask) + end + + local ret = {} + for k, v in pairs(ops) do + if bit.band(v, bitmask) > 0 then + table.insert(ret, k) + end + end + table.sort(ret) + + return ret +end + +function _M:clear_options(...) + local bitmask = 0 + for _, opt in ipairs({...}) do + bitmask = bit.bor(bitmask, opt) + end + + if OPENSSL_10 then + bitmask = C.SSL_ctrl(self.ctx, 77, bitmask, nil) -- SSL_CTRL_CLEAR_OPTIONS + else + bitmask = C.SSL_clear_options(self.ctx, bitmask) + end + + return tonumber(bitmask) +end + +local valid_protocols = { + ["SSLv3"] = ops.SSL_OP_NO_SSLv3, + ["TLSv1"] = ops.SSL_OP_NO_TLSv1, + ["TLSv1.1"] = ops.SSL_OP_NO_TLSv1_1, + ["TLSv1.2"] = ops.SSL_OP_NO_TLSv1_2, + ["TLSv1.3"] = ops.SSL_OP_NO_TLSv1_3, +} +local any_tlsv1 = ops.SSL_OP_NO_TLSv1_1 + ops.SSL_OP_NO_TLSv1_2 + ops.SSL_OP_NO_TLSv1_3 + +function _M:set_protocols(...) + local bitmask = 0 + for _, prot in ipairs({...}) do + local b = valid_protocols[prot] + if not b then + return nil, "\"" .. prot .. "\" is not a valid protocol" + end + bitmask = bit.bor(bitmask, b) + end + + if bit.band(bitmask, any_tlsv1) > 0 then + bitmask = bit.bor(bitmask, ops.SSL_OP_NO_TLSv1) + end + + -- first disable all protocols + if OPENSSL_10 then + C.SSL_ctrl(self.ctx, 32, ops.SSL_OP_NO_SSL_MASK, nil) -- SSL_CTRL_OPTIONS + else + C.SSL_set_options(self.ctx, ops.SSL_OP_NO_SSL_MASK) + end + + -- then enable selected protocols + if OPENSSL_10 then + return tonumber(C.SSL_clear_options(self.ctx, bitmask)) + else + return tonumber(C.SSL_ctrl(self.ctx, 77, bitmask, nil)) -- SSL_CTRL_CLEAR_OPTIONS) + end +end + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/ssl_ctx.lua b/server/resty/openssl/ssl_ctx.lua new file mode 100644 index 0000000..dd110f9 --- /dev/null +++ b/server/resty/openssl/ssl_ctx.lua @@ -0,0 +1,95 @@ +local ffi = require "ffi" +local C = ffi.C +local new_tab = table.new +local char = string.char +local concat = table.concat + +require "resty.openssl.include.ssl" + +local nginx_aux = require("resty.openssl.auxiliary.nginx") + +local _M = {} +local mt = {__index = _M} + +local ssl_ctx_ptr_ct = ffi.typeof('SSL_CTX*') + +function _M.from_request() + -- don't GC this + local ctx, err = nginx_aux.get_req_ssl_ctx() + if err ~= nil then + return nil, err + end + + return setmetatable({ + ctx = ctx, + -- the cdata is not manage by Lua, don't GC on Lua side + _managed = false, + -- this is the Server SSL session + _server = true, + }, mt) +end + +function _M.from_socket(socket) + if not socket then + return nil, "expect a ngx.socket.tcp instance at #1" + end + -- don't GC this + local ctx, err = nginx_aux.get_socket_ssl_ctx(socket) + if err ~= nil then + return nil, err + end + + return setmetatable({ + ctx = ctx, + -- the cdata is not manage by Lua, don't GC on Lua side + _managed = false, + -- this is the client SSL session + _server = false, + }, mt) +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(ssl_ctx_ptr_ct, l.ctx) +end + +local function encode_alpn_wire(alpns) + local ret = new_tab(#alpns*2, 0) + for i, alpn in ipairs(alpns) do + ret[i*2-1] = char(#alpn) + ret[i*2] = alpn + end + + return concat(ret, "") +end + +function _M:set_alpns(alpns) + if not self._server then + return nil, "ssl_ctx:set_alpns is only supported on server side" + end + + alpns = encode_alpn_wire(alpns) + + if self._alpn_select_cb then + self._alpn_select_cb:free() + end + + local alpn_select_cb = ffi.cast("SSL_CTX_alpn_select_cb_func", function(_, out, outlen, client, client_len) + local code = ffi.C.SSL_select_next_proto( + ffi.cast("unsigned char **", out), outlen, + alpns, #alpns, + client, client_len) + if code ~= 1 then -- OPENSSL_NPN_NEGOTIATED + return 3 -- SSL_TLSEXT_ERR_NOACK + end + return 0 -- SSL_TLSEXT_ERR_OK + end) + + C.SSL_CTX_set_alpn_select_cb(self.ctx, alpn_select_cb, nil) + -- store the reference to avoid it being GC'ed + self._alpn_select_cb = alpn_select_cb + + return true +end + + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/stack.lua b/server/resty/openssl/stack.lua new file mode 100644 index 0000000..9bdc377 --- /dev/null +++ b/server/resty/openssl/stack.lua @@ -0,0 +1,159 @@ + +--[[ + The OpenSSL stack library. Note `safestack` is not usable here in ffi because + those symbols are eaten after preprocessing. + Instead, we should do a Lua land type checking by having a nested field indicating + which type of cdata its ctx holds. +]] +local ffi = require "ffi" +local C = ffi.C +local ffi_cast = ffi.cast +local ffi_gc = ffi.gc + +local stack_macro = require "resty.openssl.include.stack" +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +local function gc_of(typ) + local f = C[typ .. "_free"] + return function (st) + stack_macro.OPENSSL_sk_pop_free(st, f) + end +end + +_M.gc_of = gc_of + +_M.mt_of = function(typ, convert, index_tbl, no_gc) + if type(typ) ~= "string" then + error("expect a string at #1") + elseif type(convert) ~= "function" then + error("expect a function at #2") + end + + local typ_ptr = typ .. "*" + + -- starts from 0 + local function value_at(ctx, i) + local elem = stack_macro.OPENSSL_sk_value(ctx, i) + if elem == nil then + error(format_error("OPENSSL_sk_value")) + end + local dup, err = convert(ffi_cast(typ_ptr, elem)) + if err then + error(err) + end + return dup + end + + local function iter(tbl) + if not tbl then error("instance is nil") end + local i = 0 + local n = tonumber(stack_macro.OPENSSL_sk_num(tbl.ctx)) + return function() + i = i + 1 + if i <= n then + return i, value_at(tbl.ctx, i-1) + end + end + end + + local ret = { + __pairs = iter, + __ipairs = iter, + __len = function(tbl) + if not tbl then error("instance is nil") end + return tonumber(stack_macro.OPENSSL_sk_num(tbl.ctx)) + end, + __index = function(tbl, k) + if not tbl then error("instance is nil") end + local i = tonumber(k) + if not i then + return index_tbl[k] + end + local n = stack_macro.OPENSSL_sk_num(tbl.ctx) + if i <= 0 or i > n then + return nil + end + return value_at(tbl.ctx, i-1) + end, + } + + if not no_gc then + ret.__gc = gc_of(typ) + end + return ret +end + +_M.new_of = function(typ) + local gc = gc_of(typ) + return function() + local raw = stack_macro.OPENSSL_sk_new_null() + if raw == nil then + return nil, "stack.new_of: OPENSSL_sk_new_null() failed" + end + ffi_gc(raw, gc) + return raw + end +end + +_M.add_of = function(typ) + local ptr = ffi.typeof(typ .. "*") + return function(stack, ctx) + if not stack then error("instance is nil") end + if ctx == nil or not ffi.istype(ptr, ctx) then + return false, "stack.add_of: expect a " .. typ .. "* at #1" + end + local code = stack_macro.OPENSSL_sk_push(stack, ctx) + if code == 0 then + return false, "stack.add_of: OPENSSL_sk_push() failed" + end + return true + end +end + +local stack_ptr_ct = ffi.typeof("OPENSSL_STACK*") +_M.dup_of = function(_) + return function(ctx) + if ctx == nil or not ffi.istype(stack_ptr_ct, ctx) then + return nil, "stack.dup_of: expect a stack ctx at #1" + end + local ctx = stack_macro.OPENSSL_sk_dup(ctx) + if ctx == nil then + return nil, "stack.dup_of: OPENSSL_sk_dup() failed" + end + -- if the stack is duplicated: since we don't copy the elements + -- then we only control gc of the stack itself here + ffi_gc(ctx, stack_macro.OPENSSL_sk_free) + return ctx + end +end + +-- fallback function to iterate if LUAJIT_ENABLE_LUA52COMPAT not enabled +_M.all_func = function(mt) + return function(stack) + if not stack then error("stack is nil") end + local ret = {} + local _next = mt.__pairs(stack) + while true do + local i, elem = _next() + if elem then + ret[i] = elem + else + break + end + end + return ret + end +end + +_M.deep_copy_of = function(typ) + local dup = C[typ .. "_dup"] + local free = C[typ .. "_free"] + + return function(ctx) + return stack_macro.OPENSSL_sk_deep_copy(ctx, dup, free) + end +end + +return _M
\ No newline at end of file diff --git a/server/resty/openssl/version.lua b/server/resty/openssl/version.lua new file mode 100644 index 0000000..f982b61 --- /dev/null +++ b/server/resty/openssl/version.lua @@ -0,0 +1,117 @@ +-- https://github.com/GUI/lua-openssl-ffi/blob/master/lib/openssl-ffi/version.lua +local ffi = require "ffi" +local C = ffi.C +local ffi_str = ffi.string + +ffi.cdef[[ + // 1.0 + unsigned long SSLeay(void); + const char *SSLeay_version(int t); + // >= 1.1 + unsigned long OpenSSL_version_num(); + const char *OpenSSL_version(int t); + // >= 3.0 + const char *OPENSSL_info(int t); + // BoringSSL + int BORINGSSL_self_test(void); +]] + +local version_func, info_func +local types_table + +-- >= 1.1 +local ok, version_num = pcall(function() + local num = C.OpenSSL_version_num() + version_func = C.OpenSSL_version + types_table = { + VERSION = 0, + CFLAGS = 1, + BUILT_ON = 2, + PLATFORM = 3, + DIR = 4, + ENGINES_DIR = 5, + VERSION_STRING = 6, + FULL_VERSION_STRING = 7, + MODULES_DIR = 8, + CPU_INFO = 9, + } + return num +end) + + +if not ok then + -- 1.0.x + ok, version_num = pcall(function() + local num = C.SSLeay() + version_func = C.SSLeay_version + types_table = { + VERSION = 0, + CFLAGS = 2, + BUILT_ON = 3, + PLATFORM = 4, + DIR = 5, + } + return num + end) +end + + +if not ok then + error(string.format("OpenSSL has encountered an error: %s; is OpenSSL library loaded?", + tostring(version_num))) +elseif type(version_num) == 'number' and version_num < 0x10000000 then + error(string.format("OpenSSL version %s is not supported", tostring(version_num or 0))) +elseif not version_num then + error("Can not get OpenSSL version") +end + +if version_num >= 0x30000000 then + local info_table = { + INFO_CONFIG_DIR = 1001, + INFO_ENGINES_DIR = 1002, + INFO_MODULES_DIR = 1003, + INFO_DSO_EXTENSION = 1004, + INFO_DIR_FILENAME_SEPARATOR = 1005, + INFO_LIST_SEPARATOR = 1006, + INFO_SEED_SOURCE = 1007, + INFO_CPU_SETTINGS = 1008, + } + + for k, v in pairs(info_table) do + types_table[k] = v + end + + info_func = C.OPENSSL_info +else + info_func = function(_) + error(string.format("OPENSSL_info is not supported on %s", ffi_str(version_func(0)))) + end +end + +local BORINGSSL = false +pcall(function() + local _ = C.BORINGSSL_self_test + BORINGSSL = true +end) + +return setmetatable({ + version_num = tonumber(version_num), + version_text = ffi_str(version_func(0)), + version = function(t) + return ffi_str(version_func(t)) + end, + info = function(t) + return ffi_str(info_func(t)) + end, + OPENSSL_3X = version_num >= 0x30000000 and version_num < 0x30200000, + OPENSSL_30 = version_num >= 0x30000000 and version_num < 0x30100000, -- for backward compat, deprecated + OPENSSL_11 = version_num >= 0x10100000 and version_num < 0x10200000, + OPENSSL_111 = version_num >= 0x10101000 and version_num < 0x10200000, + OPENSSL_11_OR_LATER = version_num >= 0x10100000 and version_num < 0x30200000, + OPENSSL_111_OR_LATER = version_num >= 0x10101000 and version_num < 0x30200000, + OPENSSL_10 = version_num < 0x10100000 and version_num > 0x10000000, + BORINGSSL = BORINGSSL, + BORINGSSL_110 = BORINGSSL and version_num >= 0x10100000 and version_num < 0x10101000 + }, { + __index = types_table, +})
\ No newline at end of file diff --git a/server/resty/openssl/x509/altname.lua b/server/resty/openssl/x509/altname.lua new file mode 100644 index 0000000..34bf9e0 --- /dev/null +++ b/server/resty/openssl/x509/altname.lua @@ -0,0 +1,248 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_cast = ffi.cast +local ffi_str = ffi.string + +require "resty.openssl.include.x509" +require "resty.openssl.include.x509v3" +local asn1_macro = require "resty.openssl.include.asn1" +local stack_lib = require "resty.openssl.stack" +local name_lib = require "resty.openssl.x509.name" +local altname_macro = require "resty.openssl.include.x509.altname" + +local _M = {} + +local general_names_ptr_ct = ffi.typeof("GENERAL_NAMES*") + +local STACK = "GENERAL_NAME" +local new = stack_lib.new_of(STACK) +local add = stack_lib.add_of(STACK) +local dup = stack_lib.dup_of(STACK) + +local types = altname_macro.types + +local AF_INET = 2 +local AF_INET6 = 10 +if ffi.os == "OSX" then + AF_INET6 = 30 +elseif ffi.os == "BSD" then + AF_INET6 = 28 +elseif ffi.os == "Windows" then + AF_INET6 = 23 +end + +ffi.cdef [[ + typedef int socklen_t; + int inet_pton(int af, const char *restrict src, void *restrict dst); + const char *inet_ntop(int af, const void *restrict src, + char *restrict dst, socklen_t size); +]] + +local ip_buffer = ffi.new("unsigned char [46]") -- 46 bytes enough for both string ipv6 and binary ipv6 + +-- similar to GENERAL_NAME_print, but returns value instead of print +local gn_decode = function(ctx) + local typ = ctx.type + local k = altname_macro.literals[typ] + local v + if typ == types.OtherName then + v = "OtherName:<unsupported>" + elseif typ == types.RFC822Name then + v = ffi_str(asn1_macro.ASN1_STRING_get0_data(ctx.d.rfc822Name)) + elseif typ == types.DNS then + v = ffi_str(asn1_macro.ASN1_STRING_get0_data(ctx.d.dNSName)) + elseif typ == types.X400 then + v = "X400:<unsupported>" + elseif typ == types.DirName then + v = name_lib.dup(ctx.d.directoryName) + elseif typ == types.EdiParty then + v = "EdiParty:<unsupported>" + elseif typ == types.URI then + v = ffi_str(asn1_macro.ASN1_STRING_get0_data(ctx.d.uniformResourceIdentifier)) + elseif typ == types.IP then + v = asn1_macro.ASN1_STRING_get0_data(ctx.d.iPAddress) + local l = tonumber(C.ASN1_STRING_length(ctx.d.iPAddress)) + if l ~= 4 and l ~= 16 then + error("Unknown IP address type") + end + v = C.inet_ntop(l == 4 and AF_INET or AF_INET6, v, ip_buffer, 46) + v = ffi_str(v) + elseif typ == types.RID then + v = "RID:<unsupported>" + else + error("unknown type" .. typ .. "-> " .. types.OtherName) + end + return { k, v } +end + +-- shared with info_access +_M.gn_decode = gn_decode + +local mt = stack_lib.mt_of(STACK, gn_decode, _M) +local mt__pairs = mt.__pairs +mt.__pairs = function(tbl) + local f = mt__pairs(tbl) + return function() + local _, e = f() + if not e then return end + return unpack(e) + end +end + +function _M.new() + local ctx = new() + if ctx == nil then + return nil, "x509.altname.new: OPENSSL_sk_new_null() failed" + end + local cast = ffi_cast("GENERAL_NAMES*", ctx) + + local self = setmetatable({ + ctx = ctx, + cast = cast, + _is_shallow_copy = false, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.cast and ffi.istype(general_names_ptr_ct, l.cast) +end + +function _M.dup(ctx) + if ctx == nil or not ffi.istype(general_names_ptr_ct, ctx) then + return nil, "x509.altname.dup: expect a GENERAL_NAMES* ctx at #1" + end + + local dup_ctx = dup(ctx) + + return setmetatable({ + cast = ffi_cast("GENERAL_NAMES*", dup_ctx), + ctx = dup_ctx, + -- don't let lua gc the original stack to keep its elements + _dupped_from = ctx, + _is_shallow_copy = true, + _elem_refs = {}, + _elem_refs_idx = 1, + }, mt), nil +end + +local function gn_set(gn, typ, value) + if type(typ) ~= 'string' then + return "x509.altname:gn_set: expect a string at #1" + end + local typ_lower = typ:lower() + if type(value) ~= 'string' then + return "x509.altname:gn_set: except a string at #2" + end + + local txt = value + local gn_type = types[typ_lower] + + if not gn_type then + return "x509.altname:gn_set: unknown type " .. typ + end + + if gn_type == types.IP then + if C.inet_pton(AF_INET, txt, ip_buffer) == 1 then + txt = ffi_str(ip_buffer, 4) + elseif C.inet_pton(AF_INET6, txt, ip_buffer) == 1 then + txt = ffi_str(ip_buffer, 16) + else + return "x509.altname:gn_set: invalid IP address " .. txt + end + + elseif gn_type ~= types.Email and + gn_type ~= types.URI and + gn_type ~= types.DNS then + return "x509.altname:gn_set: setting type " .. typ .. " is currently not supported" + end + + gn.type = gn_type + + local asn1_string = C.ASN1_IA5STRING_new() + if asn1_string == nil then + return "x509.altname:gn_set: ASN1_STRING_type_new() failed" + end + + local code = C.ASN1_STRING_set(asn1_string, txt, #txt) + if code ~= 1 then + C.ASN1_STRING_free(asn1_string) + return "x509.altname:gn_set: ASN1_STRING_set() failed: " .. code + end + gn.d.ia5 = asn1_string +end + +-- shared with info_access +_M.gn_set = gn_set + +function _M:add(typ, value) + + -- the stack element stays with stack + -- we shouldn't add gc handler if it's already been + -- pushed to stack. instead, rely on the gc handler + -- of the stack to release all memories + local gn = C.GENERAL_NAME_new() + if gn == nil then + return nil, "x509.altname:add: GENERAL_NAME_new() failed" + end + + local err = gn_set(gn, typ, value) + if err then + C.GENERAL_NAME_free(gn) + return nil, err + end + + local _, err = add(self.ctx, gn) + if err then + C.GENERAL_NAME_free(gn) + return nil, err + end + + -- if the stack is duplicated, the gc handler is not pop_free + -- handle the gc by ourselves + if self._is_shallow_copy then + ffi_gc(gn, C.GENERAL_NAME_free) + self._elem_refs[self._elem_refs_idx] = gn + self._elem_refs_idx = self._elem_refs_idx + 1 + end + return self +end + +_M.all = function(self) + local ret = {} + local _next = mt.__pairs(self) + while true do + local k, v = _next() + if k then + ret[k] = v + else + break + end + end + return ret +end + +_M.each = mt.__pairs +_M.index = mt.__index +_M.count = mt.__len + +mt.__tostring = function(self) + local values = {} + local _next = mt.__pairs(self) + while true do + local k, v = _next() + if k then + table.insert(values, k .. "=" .. v) + else + break + end + end + table.sort(values) + return table.concat(values, "/") +end + +_M.tostring = mt.__tostring + +return _M diff --git a/server/resty/openssl/x509/chain.lua b/server/resty/openssl/x509/chain.lua new file mode 100644 index 0000000..5557ea0 --- /dev/null +++ b/server/resty/openssl/x509/chain.lua @@ -0,0 +1,76 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +local stack_lib = require "resty.openssl.stack" +local x509_lib = require "resty.openssl.x509" +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +local stack_ptr_ct = ffi.typeof("OPENSSL_STACK*") + +local STACK = "X509" +local gc = stack_lib.gc_of(STACK) +local new = stack_lib.new_of(STACK) +local add = stack_lib.add_of(STACK) +local mt = stack_lib.mt_of(STACK, x509_lib.dup, _M) + +function _M.new() + local raw = new() + + local self = setmetatable({ + stack_of = STACK, + ctx = raw, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(stack_ptr_ct, l.ctx) + and l.stack_of and l.stack_of == STACK +end + +function _M.dup(ctx) + if ctx == nil or not ffi.istype(stack_ptr_ct, ctx) then + return nil, "x509.chain.dup: expect a stack ctx at #1, got " .. type(ctx) + end + -- sk_X509_dup plus up ref for each X509 element + local ctx = C.X509_chain_up_ref(ctx) + if ctx == nil then + return nil, "x509.chain.dup: X509_chain_up_ref() failed" + end + ffi_gc(ctx, gc) + + return setmetatable({ + stack_of = STACK, + ctx = ctx, + }, mt) +end + +function _M:add(x509) + if not x509_lib.istype(x509) then + return nil, "x509.chain:add: expect a x509 instance at #1" + end + + local dup = C.X509_dup(x509.ctx) + if dup == nil then + return nil, format_error("x509.chain:add: X509_dup") + end + + local _, err = add(self.ctx, dup) + if err then + C.X509_free(dup) + return nil, err + end + + return true +end + +_M.all = stack_lib.all_func(mt) +_M.each = mt.__ipairs +_M.index = mt.__index +_M.count = mt.__len + +return _M diff --git a/server/resty/openssl/x509/crl.lua b/server/resty/openssl/x509/crl.lua new file mode 100644 index 0000000..3ee4501 --- /dev/null +++ b/server/resty/openssl/x509/crl.lua @@ -0,0 +1,607 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +require "resty.openssl.include.x509.crl" +require "resty.openssl.include.pem" +require "resty.openssl.include.x509v3" +local asn1_lib = require("resty.openssl.asn1") +local bn_lib = require("resty.openssl.bn") +local revoked_lib = require("resty.openssl.x509.revoked") +local digest_lib = require("resty.openssl.digest") +local extension_lib = require("resty.openssl.x509.extension") +local pkey_lib = require("resty.openssl.pkey") +local bio_util = require "resty.openssl.auxiliary.bio" +local ctx_lib = require "resty.openssl.ctx" +local stack_lib = require "resty.openssl.stack" +local txtnid2nid = require("resty.openssl.objects").txtnid2nid +local find_sigid_algs = require("resty.openssl.objects").find_sigid_algs +local format_error = require("resty.openssl.err").format_error +local version = require("resty.openssl.version") +local OPENSSL_10 = version.OPENSSL_10 +local OPENSSL_11_OR_LATER = version.OPENSSL_11_OR_LATER +local OPENSSL_3X = version.OPENSSL_3X +local BORINGSSL = version.BORINGSSL +local BORINGSSL_110 = version.BORINGSSL_110 -- used in boringssl-fips-20190808 + +local accessors = {} + +accessors.set_issuer_name = C.X509_CRL_set_issuer_name +accessors.set_version = C.X509_CRL_set_version + + +if OPENSSL_11_OR_LATER and not BORINGSSL_110 then + accessors.get_last_update = C.X509_CRL_get0_lastUpdate + accessors.set_last_update = C.X509_CRL_set1_lastUpdate + accessors.get_next_update = C.X509_CRL_get0_nextUpdate + accessors.set_next_update = C.X509_CRL_set1_nextUpdate + accessors.get_version = C.X509_CRL_get_version + accessors.get_issuer_name = C.X509_CRL_get_issuer -- returns internal ptr + accessors.get_signature_nid = C.X509_CRL_get_signature_nid + -- BORINGSSL_110 exports X509_CRL_get_signature_nid, but just ignored for simplicity + accessors.get_revoked = C.X509_CRL_get_REVOKED +elseif OPENSSL_10 or BORINGSSL_110 then + accessors.get_last_update = function(crl) + if crl == nil or crl.crl == nil then + return nil + end + return crl.crl.lastUpdate + end + accessors.set_last_update = C.X509_CRL_set_lastUpdate + accessors.get_next_update = function(crl) + if crl == nil or crl.crl == nil then + return nil + end + return crl.crl.nextUpdate + end + accessors.set_next_update = C.X509_CRL_set_nextUpdate + accessors.get_version = function(crl) + if crl == nil or crl.crl == nil then + return nil + end + return C.ASN1_INTEGER_get(crl.crl.version) + end + accessors.get_issuer_name = function(crl) + if crl == nil or crl.crl == nil then + return nil + end + return crl.crl.issuer + end + accessors.get_signature_nid = function(crl) + if crl == nil or crl.crl == nil or crl.crl.sig_alg == nil then + return nil + end + return C.OBJ_obj2nid(crl.crl.sig_alg.algorithm) + end + accessors.get_revoked = function(crl) + return crl.crl.revoked + end +end + +local function __tostring(self, fmt) + if not fmt or fmt == 'PEM' then + return bio_util.read_wrap(C.PEM_write_bio_X509_CRL, self.ctx) + elseif fmt == 'DER' then + return bio_util.read_wrap(C.i2d_X509_CRL_bio, self.ctx) + else + return nil, "x509.crl:tostring: can only write PEM or DER format, not " .. fmt + end +end + +local _M = {} +local mt = { __index = _M, __tostring = __tostring } + +local x509_crl_ptr_ct = ffi.typeof("X509_CRL*") + +function _M.new(crl, fmt, properties) + local ctx + if not crl then + if OPENSSL_3X then + ctx = C.X509_CRL_new_ex(ctx_lib.get_libctx(), properties) + else + ctx = C.X509_CRL_new() + end + if ctx == nil then + return nil, "x509.crl.new: X509_CRL_new() failed" + end + elseif type(crl) == "string" then + -- routine for load an existing csr + local bio = C.BIO_new_mem_buf(crl, #crl) + if bio == nil then + return nil, format_error("x509.crl.new: BIO_new_mem_buf") + end + + fmt = fmt or "*" + while true do -- luacheck: ignore 512 -- loop is executed at most once + if fmt == "PEM" or fmt == "*" then + ctx = C.PEM_read_bio_X509_CRL(bio, nil, nil, nil) + if ctx ~= nil then + break + elseif fmt == "*" then + -- BIO_reset; #define BIO_CTRL_RESET 1 + local code = C.BIO_ctrl(bio, 1, 0, nil) + if code ~= 1 then + return nil, "x509.crl.new: BIO_ctrl() failed: " .. code + end + end + end + if fmt == "DER" or fmt == "*" then + ctx = C.d2i_X509_CRL_bio(bio, nil) + end + break + end + C.BIO_free(bio) + if ctx == nil then + return nil, format_error("x509.crl.new") + end + -- clear errors occur when trying + C.ERR_clear_error() + else + return nil, "x509.crl.new: expect nil or a string at #1" + end + ffi_gc(ctx, C.X509_CRL_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l and l.ctx and ffi.istype(x509_crl_ptr_ct, l.ctx) +end + +function _M.dup(ctx) + if not ffi.istype(x509_crl_ptr_ct, ctx) then + return nil, "x509.crl.dup: expect a x509.crl ctx at #1" + end + local ctx = C.X509_CRL_dup(ctx) + if ctx == nil then + return nil, "x509.crl.dup: X509_CRL_dup() failed" + end + + ffi_gc(ctx, C.X509_CRL_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M:tostring(fmt) + return __tostring(self, fmt) +end + +function _M:to_PEM() + return __tostring(self, "PEM") +end + +function _M:text() + return bio_util.read_wrap(C.X509_CRL_print, self.ctx) +end + +local function revoked_decode(ctx) + if OPENSSL_10 then + error("x509.crl:revoked_decode: not supported on OpenSSL 1.0") + end + + local ret = {} + local serial = C.X509_REVOKED_get0_serialNumber(ctx) + if serial ~= nil then + serial = C.ASN1_INTEGER_to_BN(serial, nil) + if serial == nil then + error("x509.crl:revoked_decode: ASN1_INTEGER_to_BN() failed") + end + ffi_gc(serial, C.BN_free) + ret["serial_number"] = bn_lib.to_hex({ctx = serial}) + end + + local date = C.X509_REVOKED_get0_revocationDate(ctx) + if date ~= nil then + date = asn1_lib.asn1_to_unix(date) + ret["revocation_date"] = date + end + + return ret +end + +local revoked_mt = stack_lib.mt_of("X509_REVOKED", revoked_decode, _M) + +local function nil_iter() return nil end +local function revoked_iter(self) + local stack = accessors.get_revoked(self.ctx) + if stack == nil then + return nil_iter + end + + return revoked_mt.__ipairs({ctx = stack}) +end + +mt.__pairs = revoked_iter +mt.__ipairs = revoked_iter +mt.__index = function(self, k) + local i = tonumber(k) + if not i then + return _M[k] + end + + local stack = accessors.get_revoked(self.ctx) + if stack == nil then + return nil + end + + return revoked_mt.__index({ctx = stack}, i) +end +mt.__len = function(self) + local stack = accessors.get_revoked(self.ctx) + if stack == nil then + return 0 + end + + return revoked_mt.__len({ctx = stack}) +end + +_M.all = function(self) + local ret = {} + local _next = mt.__pairs(self) + while true do + local k, v = _next() + if k then + ret[k] = v + else + break + end + end + return ret +end +_M.each = mt.__pairs +_M.index = mt.__index +_M.count = mt.__len + +--- Adds revoked item to stack of revoked certificates of crl +-- @tparam table Instance of crl module +-- @tparam table Instance of revoked module +-- @treturn boolean true if revoked item was successfully added or false otherwise +-- @treturn[opt] string Returns optional error message in case of error +function _M:add_revoked(revoked) + if not revoked_lib.istype(revoked) then + return false, "x509.crl:add_revoked: expect a revoked instance at #1" + end + local ctx = C.X509_REVOKED_dup(revoked.ctx) + if ctx == nil then + return nil, "x509.crl:add_revoked: X509_REVOKED_dup() failed" + end + + if C.X509_CRL_add0_revoked(self.ctx, ctx) == 0 then + return false, format_error("x509.crl:add_revoked") + end + + return true +end + +local ptr_ptr_of_x509_revoked = ffi.typeof("X509_REVOKED*[1]") +function _M:get_by_serial(sn) + local bn, err + if bn_lib.istype(sn) then + bn = sn + elseif type(sn) == "string" then + bn, err = bn_lib.from_hex(sn) + if err then + return nil, "x509.crl:find: can't decode bn: " .. err + end + else + return nil, "x509.crl:find: expect a bn instance at #1" + end + + local sn_asn1 = C.BN_to_ASN1_INTEGER(bn.ctx, nil) + if sn_asn1 == nil then + return nil, "x509.crl:find: BN_to_ASN1_INTEGER() failed" + end + ffi_gc(sn_asn1, C.ASN1_INTEGER_free) + + local pp = ptr_ptr_of_x509_revoked() + local code = C.X509_CRL_get0_by_serial(self.ctx, pp, sn_asn1) + if code == 1 then + return revoked_decode(pp[0]) + elseif code == 2 then + return nil, "not revoked (removeFromCRL)" + end + + -- 0 or other + return nil +end + + +-- START AUTO GENERATED CODE + +-- AUTO GENERATED +function _M:sign(pkey, digest) + if not pkey_lib.istype(pkey) then + return false, "x509.crl:sign: expect a pkey instance at #1" + end + + local digest_algo + if digest then + if not digest_lib.istype(digest) then + return false, "x509.crl:sign: expect a digest instance at #2" + elseif not digest.algo then + return false, "x509.crl:sign: expect a digest instance to have algo member" + end + digest_algo = digest.algo + elseif BORINGSSL then + digest_algo = C.EVP_get_digestbyname('sha256') + end + + -- returns size of signature if success + if C.X509_CRL_sign(self.ctx, pkey.ctx, digest_algo) == 0 then + return false, format_error("x509.crl:sign") + end + + return true +end + +-- AUTO GENERATED +function _M:verify(pkey) + if not pkey_lib.istype(pkey) then + return false, "x509.crl:verify: expect a pkey instance at #1" + end + + local code = C.X509_CRL_verify(self.ctx, pkey.ctx) + if code == 1 then + return true + elseif code == 0 then + return false + else -- typically -1 + return false, format_error("x509.crl:verify", code) + end +end + +-- AUTO GENERATED +local function get_extension(ctx, nid_txt, last_pos) + last_pos = (last_pos or 0) - 1 + local nid, err = txtnid2nid(nid_txt) + if err then + return nil, nil, err + end + local pos = C.X509_CRL_get_ext_by_NID(ctx, nid, last_pos) + if pos == -1 then + return nil + end + local ctx = C.X509_CRL_get_ext(ctx, pos) + if ctx == nil then + return nil, nil, format_error() + end + return ctx, pos +end + +-- AUTO GENERATED +function _M:add_extension(extension) + if not extension_lib.istype(extension) then + return false, "x509.crl:add_extension: expect a x509.extension instance at #1" + end + + -- X509_CRL_add_ext returnes the stack on success, and NULL on error + -- the X509_EXTENSION ctx is dupped internally + if C.X509_CRL_add_ext(self.ctx, extension.ctx, -1) == nil then + return false, format_error("x509.crl:add_extension") + end + + return true +end + +-- AUTO GENERATED +function _M:get_extension(nid_txt, last_pos) + local ctx, pos, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, nil, "x509.crl:get_extension: " .. err + end + local ext, err = extension_lib.dup(ctx) + if err then + return nil, nil, "x509.crl:get_extension: " .. err + end + return ext, pos+1 +end + +local X509_CRL_delete_ext +if OPENSSL_11_OR_LATER then + X509_CRL_delete_ext = C.X509_CRL_delete_ext +elseif OPENSSL_10 then + X509_CRL_delete_ext = function(ctx, pos) + return C.X509v3_delete_ext(ctx.crl.extensions, pos) + end +else + X509_CRL_delete_ext = function(...) + error("X509_CRL_delete_ext undefined") + end +end + +-- AUTO GENERATED +function _M:set_extension(extension, last_pos) + if not extension_lib.istype(extension) then + return false, "x509.crl:set_extension: expect a x509.extension instance at #1" + end + + last_pos = (last_pos or 0) - 1 + + local nid = extension:get_object().nid + local pos = C.X509_CRL_get_ext_by_NID(self.ctx, nid, last_pos) + -- pos may be -1, which means not found, it's fine, we will add new one instead of replace + + local removed = X509_CRL_delete_ext(self.ctx, pos) + C.X509_EXTENSION_free(removed) + + if C.X509_CRL_add_ext(self.ctx, extension.ctx, pos) == nil then + return false, format_error("x509.crl:set_extension") + end + + return true +end + +-- AUTO GENERATED +function _M:set_extension_critical(nid_txt, crit, last_pos) + local ctx, _, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, "x509.crl:set_extension_critical: " .. err + end + + if C.X509_EXTENSION_set_critical(ctx, crit and 1 or 0) ~= 1 then + return false, format_error("x509.crl:set_extension_critical") + end + + return true +end + +-- AUTO GENERATED +function _M:get_extension_critical(nid_txt, last_pos) + local ctx, _, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, "x509.crl:get_extension_critical: " .. err + end + + return C.X509_EXTENSION_get_critical(ctx) == 1 +end + +-- AUTO GENERATED +function _M:get_issuer_name() + local got = accessors.get_issuer_name(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.x509.name") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED +function _M:set_issuer_name(toset) + local lib = require("resty.openssl.x509.name") + if lib.istype and not lib.istype(toset) then + return false, "x509.crl:set_issuer_name: expect a x509.name instance at #1" + end + toset = toset.ctx + if accessors.set_issuer_name(self.ctx, toset) == 0 then + return false, format_error("x509.crl:set_issuer_name") + end + return true +end + +-- AUTO GENERATED +function _M:get_last_update() + local got = accessors.get_last_update(self.ctx) + if got == nil then + return nil + end + + got = asn1_lib.asn1_to_unix(got) + + return got +end + +-- AUTO GENERATED +function _M:set_last_update(toset) + if type(toset) ~= "number" then + return false, "x509.crl:set_last_update: expect a number at #1" + end + + toset = C.ASN1_TIME_set(nil, toset) + ffi_gc(toset, C.ASN1_STRING_free) + + if accessors.set_last_update(self.ctx, toset) == 0 then + return false, format_error("x509.crl:set_last_update") + end + return true +end + +-- AUTO GENERATED +function _M:get_next_update() + local got = accessors.get_next_update(self.ctx) + if got == nil then + return nil + end + + got = asn1_lib.asn1_to_unix(got) + + return got +end + +-- AUTO GENERATED +function _M:set_next_update(toset) + if type(toset) ~= "number" then + return false, "x509.crl:set_next_update: expect a number at #1" + end + + toset = C.ASN1_TIME_set(nil, toset) + ffi_gc(toset, C.ASN1_STRING_free) + + if accessors.set_next_update(self.ctx, toset) == 0 then + return false, format_error("x509.crl:set_next_update") + end + return true +end + +-- AUTO GENERATED +function _M:get_version() + local got = accessors.get_version(self.ctx) + if got == nil then + return nil + end + + got = tonumber(got) + 1 + + return got +end + +-- AUTO GENERATED +function _M:set_version(toset) + if type(toset) ~= "number" then + return false, "x509.crl:set_version: expect a number at #1" + end + + -- Note: this is defined by standards (X.509 et al) to be one less than the certificate version. + -- So a version 3 certificate will return 2 and a version 1 certificate will return 0. + toset = toset - 1 + + if accessors.set_version(self.ctx, toset) == 0 then + return false, format_error("x509.crl:set_version") + end + return true +end + + +-- AUTO GENERATED +function _M:get_signature_nid() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.crl:get_signature_nid") + end + + return nid +end + +-- AUTO GENERATED +function _M:get_signature_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.crl:get_signature_name") + end + + return ffi.string(C.OBJ_nid2sn(nid)) +end + +-- AUTO GENERATED +function _M:get_signature_digest_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.crl:get_signature_digest_name") + end + + local nid = find_sigid_algs(nid) + + return ffi.string(C.OBJ_nid2sn(nid)) +end +-- END AUTO GENERATED CODE + +return _M + diff --git a/server/resty/openssl/x509/csr.lua b/server/resty/openssl/x509/csr.lua new file mode 100644 index 0000000..08c4860 --- /dev/null +++ b/server/resty/openssl/x509/csr.lua @@ -0,0 +1,531 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_cast = ffi.cast + +require "resty.openssl.include.pem" +require "resty.openssl.include.x509v3" +require "resty.openssl.include.x509.csr" +require "resty.openssl.include.asn1" +local stack_macro = require "resty.openssl.include.stack" +local stack_lib = require "resty.openssl.stack" +local pkey_lib = require "resty.openssl.pkey" +local digest_lib = require("resty.openssl.digest") +local extension_lib = require("resty.openssl.x509.extension") +local extensions_lib = require("resty.openssl.x509.extensions") +local bio_util = require "resty.openssl.auxiliary.bio" +local ctypes = require "resty.openssl.auxiliary.ctypes" +local ctx_lib = require "resty.openssl.ctx" +local txtnid2nid = require("resty.openssl.objects").txtnid2nid +local find_sigid_algs = require("resty.openssl.objects").find_sigid_algs +local format_error = require("resty.openssl.err").format_error +local version = require("resty.openssl.version") +local OPENSSL_10 = version.OPENSSL_10 +local OPENSSL_11_OR_LATER = version.OPENSSL_11_OR_LATER +local OPENSSL_3X = version.OPENSSL_3X +local BORINGSSL = version.BORINGSSL +local BORINGSSL_110 = version.BORINGSSL_110 -- used in boringssl-fips-20190808 + +local accessors = {} + +accessors.set_subject_name = C.X509_REQ_set_subject_name +accessors.get_pubkey = C.X509_REQ_get_pubkey +accessors.set_pubkey = C.X509_REQ_set_pubkey +accessors.set_version = C.X509_REQ_set_version + +if OPENSSL_11_OR_LATER or BORINGSSL_110 then + accessors.get_signature_nid = C.X509_REQ_get_signature_nid +elseif OPENSSL_10 then + accessors.get_signature_nid = function(csr) + if csr == nil or csr.sig_alg == nil then + return nil + end + return C.OBJ_obj2nid(csr.sig_alg.algorithm) + end +end + +if OPENSSL_11_OR_LATER and not BORINGSSL_110 then + accessors.get_subject_name = C.X509_REQ_get_subject_name -- returns internal ptr + accessors.get_version = C.X509_REQ_get_version +elseif OPENSSL_10 or BORINGSSL_110 then + accessors.get_subject_name = function(csr) + if csr == nil or csr.req_info == nil then + return nil + end + return csr.req_info.subject + end + accessors.get_version = function(csr) + if csr == nil or csr.req_info == nil then + return nil + end + return C.ASN1_INTEGER_get(csr.req_info.version) + end +end + +local function __tostring(self, fmt) + if not fmt or fmt == 'PEM' then + return bio_util.read_wrap(C.PEM_write_bio_X509_REQ, self.ctx) + elseif fmt == 'DER' then + return bio_util.read_wrap(C.i2d_X509_REQ_bio, self.ctx) + else + return nil, "x509.csr:tostring: can only write PEM or DER format, not " .. fmt + end +end + +local _M = {} +local mt = { __index = _M, __tostring = __tostring } + +local x509_req_ptr_ct = ffi.typeof("X509_REQ*") + +local stack_ptr_type = ffi.typeof("struct stack_st *[1]") +local x509_extensions_gc = stack_lib.gc_of("X509_EXTENSION") + +function _M.new(csr, fmt, properties) + local ctx + if not csr then + if OPENSSL_3X then + ctx = C.X509_REQ_new_ex(ctx_lib.get_libctx(), properties) + else + ctx = C.X509_REQ_new() + end + if ctx == nil then + return nil, "x509.csr.new: X509_REQ_new() failed" + end + elseif type(csr) == "string" then + -- routine for load an existing csr + local bio = C.BIO_new_mem_buf(csr, #csr) + if bio == nil then + return nil, format_error("x509.csr.new: BIO_new_mem_buf") + end + + fmt = fmt or "*" + while true do -- luacheck: ignore 512 -- loop is executed at most once + if fmt == "PEM" or fmt == "*" then + ctx = C.PEM_read_bio_X509_REQ(bio, nil, nil, nil) + if ctx ~= nil then + break + elseif fmt == "*" then + -- BIO_reset; #define BIO_CTRL_RESET 1 + local code = C.BIO_ctrl(bio, 1, 0, nil) + if code ~= 1 then + return nil, "x509.csr.new: BIO_ctrl() failed: " .. code + end + end + end + if fmt == "DER" or fmt == "*" then + ctx = C.d2i_X509_REQ_bio(bio, nil) + end + break + end + C.BIO_free(bio) + if ctx == nil then + return nil, format_error("x509.csr.new") + end + -- clear errors occur when trying + C.ERR_clear_error() + else + return nil, "x509.csr.new: expect nil or a string at #1" + end + ffi_gc(ctx, C.X509_REQ_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l and l.ctx and ffi.istype(x509_req_ptr_ct, l.ctx) +end + +function _M:tostring(fmt) + return __tostring(self, fmt) +end + +function _M:to_PEM() + return __tostring(self, "PEM") +end + +function _M:check_private_key(key) + if not pkey_lib.istype(key) then + return false, "x509.csr:check_private_key: except a pkey instance at #1" + end + + if not key:is_private() then + return false, "x509.csr:check_private_key: not a private key" + end + + if C.X509_REQ_check_private_key(self.ctx, key.ctx) == 1 then + return true + end + return false, format_error("x509.csr:check_private_key") +end + +--- Get all csr extensions +-- @tparam table self Instance of csr +-- @treturn Extensions object +function _M:get_extensions() + local extensions = C.X509_REQ_get_extensions(self.ctx) + -- GC handler is sk_X509_EXTENSION_pop_free + ffi_gc(extensions, x509_extensions_gc) + + return extensions_lib.dup(extensions) +end + +local function get_extension(ctx, nid_txt, last_pos) + local nid, err = txtnid2nid(nid_txt) + if err then + return nil, nil, err + end + + local extensions = C.X509_REQ_get_extensions(ctx) + if extensions == nil then + return nil, nil, format_error("csr.get_extension: X509_REQ_get_extensions") + end + ffi_gc(extensions, x509_extensions_gc) + + -- make 1-index array to 0-index + last_pos = (last_pos or 0) -1 + local ext_idx = C.X509v3_get_ext_by_NID(extensions, nid, last_pos) + if ext_idx == -1 then + err = ("X509v3_get_ext_by_NID extension for %d not found"):format(nid) + return nil, -1, format_error(err) + end + + local ctx = C.X509v3_get_ext(extensions, ext_idx) + if ctx == nil then + return nil, nil, format_error("X509v3_get_ext") + end + + return ctx, ext_idx, nil +end + +--- Get a csr extension +-- @tparam table self Instance of csr +-- @tparam string|number Nid number or name of the extension +-- @tparam number Position to start looking for the extension; default to look from start if omitted +-- @treturn Parsed extension object or nil if not found +function _M:get_extension(nid_txt, last_pos) + local ctx, pos, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, nil, "x509.csr:get_extension: " .. err + end + local ext, err = extension_lib.dup(ctx) + if err then + return nil, nil, "x509.csr:get_extension: " .. err + end + return ext, pos+1 +end + +local function modify_extension(replace, ctx, nid, toset, crit) + local extensions_ptr = stack_ptr_type() + extensions_ptr[0] = C.X509_REQ_get_extensions(ctx) + local need_cleanup = extensions_ptr[0] ~= nil and + -- extensions_ptr being nil is fine: it may just because there's no extension yet + -- https://github.com/openssl/openssl/commit/2039ac07b401932fa30a05ade80b3626e189d78a + -- introduces a change that a empty stack instead of NULL will be returned in no extension + -- is found. so we need to double check the number if it's not NULL. + stack_macro.OPENSSL_sk_num(extensions_ptr[0]) > 0 + + local flag + if replace then + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + flag = 0x2 + else + -- x509v3.h: # define X509V3_ADD_APPEND 1L + flag = 0x1 + end + + local code = C.X509V3_add1_i2d(extensions_ptr, nid, toset, crit and 1 or 0, flag) + -- when the stack is newly allocated, we want to cleanup the newly created stack as well + -- setting the gc handler here as it's mutated in X509V3_add1_i2d if it's pointing to NULL + ffi_gc(extensions_ptr[0], x509_extensions_gc) + if code ~= 1 then + return false, format_error("X509V3_add1_i2d", code) + end + + code = C.X509_REQ_add_extensions(ctx, extensions_ptr[0]) + if code ~= 1 then + return false, format_error("X509_REQ_add_extensions", code) + end + + if need_cleanup then + -- cleanup old attributes + -- delete the first only, why? + local attr = C.X509_REQ_delete_attr(ctx, 0) + if attr ~= nil then + C.X509_ATTRIBUTE_free(attr) + end + end + + -- mark encoded form as invalid so next time it will be re-encoded + if OPENSSL_11_OR_LATER then + C.i2d_re_X509_REQ_tbs(ctx, nil) + else + ctx.req_info.enc.modified = 1 + end + + return true +end + +local function add_extension(...) + return modify_extension(false, ...) +end + +local function replace_extension(...) + return modify_extension(true, ...) +end + +function _M:add_extension(extension) + if not extension_lib.istype(extension) then + return false, "x509:set_extension: expect a x509.extension instance at #1" + end + + local nid = extension:get_object().nid + local toset = extension_lib.to_data(extension, nid) + return add_extension(self.ctx, nid, toset.ctx, extension:get_critical()) +end + +function _M:set_extension(extension) + if not extension_lib.istype(extension) then + return false, "x509:set_extension: expect a x509.extension instance at #1" + end + + local nid = extension:get_object().nid + local toset = extension_lib.to_data(extension, nid) + return replace_extension(self.ctx, nid, toset.ctx, extension:get_critical()) +end + +function _M:set_extension_critical(nid_txt, crit, last_pos) + local nid, err = txtnid2nid(nid_txt) + if err then + return nil, "x509.csr:set_extension_critical: " .. err + end + + local extension, _, err = get_extension(self.ctx, nid, last_pos) + if err then + return nil, "x509.csr:set_extension_critical: " .. err + end + + local toset = extension_lib.to_data({ + ctx = extension + }, nid) + return replace_extension(self.ctx, nid, toset.ctx, crit and 1 or 0) +end + +function _M:get_extension_critical(nid_txt, last_pos) + local ctx, _, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, "x509.csr:get_extension_critical: " .. err + end + + return C.X509_EXTENSION_get_critical(ctx) == 1 +end + +-- START AUTO GENERATED CODE + +-- AUTO GENERATED +function _M:sign(pkey, digest) + if not pkey_lib.istype(pkey) then + return false, "x509.csr:sign: expect a pkey instance at #1" + end + + local digest_algo + if digest then + if not digest_lib.istype(digest) then + return false, "x509.csr:sign: expect a digest instance at #2" + elseif not digest.algo then + return false, "x509.csr:sign: expect a digest instance to have algo member" + end + digest_algo = digest.algo + elseif BORINGSSL then + digest_algo = C.EVP_get_digestbyname('sha256') + end + + -- returns size of signature if success + if C.X509_REQ_sign(self.ctx, pkey.ctx, digest_algo) == 0 then + return false, format_error("x509.csr:sign") + end + + return true +end + +-- AUTO GENERATED +function _M:verify(pkey) + if not pkey_lib.istype(pkey) then + return false, "x509.csr:verify: expect a pkey instance at #1" + end + + local code = C.X509_REQ_verify(self.ctx, pkey.ctx) + if code == 1 then + return true + elseif code == 0 then + return false + else -- typically -1 + return false, format_error("x509.csr:verify", code) + end +end +-- AUTO GENERATED +function _M:get_subject_name() + local got = accessors.get_subject_name(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.x509.name") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED +function _M:set_subject_name(toset) + local lib = require("resty.openssl.x509.name") + if lib.istype and not lib.istype(toset) then + return false, "x509.csr:set_subject_name: expect a x509.name instance at #1" + end + toset = toset.ctx + if accessors.set_subject_name(self.ctx, toset) == 0 then + return false, format_error("x509.csr:set_subject_name") + end + return true +end + +-- AUTO GENERATED +function _M:get_pubkey() + local got = accessors.get_pubkey(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.pkey") + -- returned a copied instance directly + return lib.new(got) +end + +-- AUTO GENERATED +function _M:set_pubkey(toset) + local lib = require("resty.openssl.pkey") + if lib.istype and not lib.istype(toset) then + return false, "x509.csr:set_pubkey: expect a pkey instance at #1" + end + toset = toset.ctx + if accessors.set_pubkey(self.ctx, toset) == 0 then + return false, format_error("x509.csr:set_pubkey") + end + return true +end + +-- AUTO GENERATED +function _M:get_version() + local got = accessors.get_version(self.ctx) + if got == nil then + return nil + end + + got = tonumber(got) + 1 + + return got +end + +-- AUTO GENERATED +function _M:set_version(toset) + if type(toset) ~= "number" then + return false, "x509.csr:set_version: expect a number at #1" + end + + -- Note: this is defined by standards (X.509 et al) to be one less than the certificate version. + -- So a version 3 certificate will return 2 and a version 1 certificate will return 0. + toset = toset - 1 + + if accessors.set_version(self.ctx, toset) == 0 then + return false, format_error("x509.csr:set_version") + end + return true +end + +local NID_subject_alt_name = C.OBJ_sn2nid("subjectAltName") +assert(NID_subject_alt_name ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_subject_alt_name() + local crit = ctypes.ptr_of_int() + local extensions = C.X509_REQ_get_extensions(self.ctx) + -- GC handler is sk_X509_EXTENSION_pop_free + ffi_gc(extensions, x509_extensions_gc) + local got = C.X509V3_get_d2i(extensions, NID_subject_alt_name, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509.csr:get_subject_alt_name: extension of subject_alt_name occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509.csr:get_subject_alt_name") + end + + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a GENERAL_NAME + -- we left the elements referenced by the new-dup'ed stack + local got_ref = got + ffi_gc(got_ref, stack_lib.gc_of("GENERAL_NAME")) + got = ffi_cast("GENERAL_NAMES*", got_ref) + local lib = require("resty.openssl.x509.altname") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_subject_alt_name(toset) + local lib = require("resty.openssl.x509.altname") + if lib.istype and not lib.istype(toset) then + return false, "x509.csr:set_subject_alt_name: expect a x509.altname instance at #1" + end + toset = toset.ctx + return replace_extension(self.ctx, NID_subject_alt_name, toset) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_subject_alt_name_critical(crit) + return _M.set_extension_critical(self, NID_subject_alt_name, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_subject_alt_name_critical() + return _M.get_extension_critical(self, NID_subject_alt_name) +end + + +-- AUTO GENERATED +function _M:get_signature_nid() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.csr:get_signature_nid") + end + + return nid +end + +-- AUTO GENERATED +function _M:get_signature_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.csr:get_signature_name") + end + + return ffi.string(C.OBJ_nid2sn(nid)) +end + +-- AUTO GENERATED +function _M:get_signature_digest_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509.csr:get_signature_digest_name") + end + + local nid = find_sigid_algs(nid) + + return ffi.string(C.OBJ_nid2sn(nid)) +end +-- END AUTO GENERATED CODE + +return _M + diff --git a/server/resty/openssl/x509/extension.lua b/server/resty/openssl/x509/extension.lua new file mode 100644 index 0000000..ca23158 --- /dev/null +++ b/server/resty/openssl/x509/extension.lua @@ -0,0 +1,281 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_new = ffi.new +local ffi_cast = ffi.cast +local ffi_str = ffi.string + +require "resty.openssl.include.x509" +require "resty.openssl.include.x509.extension" +require "resty.openssl.include.x509v3" +require "resty.openssl.include.bio" +require "resty.openssl.include.conf" +local asn1_macro = require("resty.openssl.include.asn1") +local objects_lib = require "resty.openssl.objects" +local stack_lib = require("resty.openssl.stack") +local bio_util = require "resty.openssl.auxiliary.bio" +local format_error = require("resty.openssl.err").format_error +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X +local BORINGSSL = require("resty.openssl.version").BORINGSSL + +local _M = {} +local mt = { __index = _M } + +local x509_extension_ptr_ct = ffi.typeof("X509_EXTENSION*") + +local extension_types = { + issuer = "resty.openssl.x509", + subject = "resty.openssl.x509", + request = "resty.openssl.x509.csr", + crl = "resty.openssl.x509.crl", +} + +if OPENSSL_3X then + extension_types["issuer_pkey"] = "resty.openssl.pkey" +end + +local nconf_load +if BORINGSSL then + nconf_load = function() + return nil, "NCONF_load_bio not exported in BoringSSL" + end +else + nconf_load = function(conf, str) + local bio = C.BIO_new_mem_buf(str, #str) + if bio == nil then + return format_error("BIO_new_mem_buf") + end + ffi_gc(bio, C.BIO_free) + + if C.NCONF_load_bio(conf, bio, nil) ~= 1 then + return format_error("NCONF_load_bio") + end + end +end + +function _M.new(txtnid, value, data) + local nid, err = objects_lib.txtnid2nid(txtnid) + if err then + return nil, "x509.extension.new: " .. err + end + if type(value) ~= 'string' then + return nil, "x509.extension.new: expect string at #2" + end + -- get a ptr and also zerofill the struct + local x509_ctx_ptr = ffi_new('X509V3_CTX[1]') + + local conf = C.NCONF_new(nil) + if conf == nil then + return nil, format_error("NCONF_new") + end + ffi_gc(conf, C.NCONF_free) + + if type(data) == 'table' then + local args = {} + if data.db then + if type(data.db) ~= 'string' then + return nil, "x509.extension.new: expect data.db must be a string" + end + err = nconf_load(conf, data) + if err then + return nil, "x509.extension.new: " .. err + end + end + + for k, t in pairs(extension_types) do + if data[k] then + local lib = require(t) + if not lib.istype(data[k]) then + return nil, "x509.extension.new: expect data." .. k .. " to be a " .. t .. " instance" + end + args[k] = data[k].ctx + end + end + C.X509V3_set_ctx(x509_ctx_ptr[0], args.issuer, args.subject, args.request, args.crl, 0) + + if OPENSSL_3X and args.issuer_pkey then + if C.X509V3_set_issuer_pkey(x509_ctx_ptr[0], args.issuer_pkey) ~= 1 then + return nil, format_error("x509.extension.new: X509V3_set_issuer_pkey") + end + end + + elseif type(data) == 'string' then + err = nconf_load(conf, data) + if err then + return nil, "x509.extension.new: " .. err + end + elseif data then + return nil, "x509.extension.new: expect nil, string a table at #3" + end + + -- setting conf is required for some extensions to load + -- crypto/x509/v3_conf.c:do_ext_conf "else if (method->r2i) {" branch + C.X509V3_set_nconf(x509_ctx_ptr[0], conf) + + local ctx = C.X509V3_EXT_nconf_nid(conf, x509_ctx_ptr[0], nid, value) + if ctx == nil then + return nil, format_error("x509.extension.new") + end + ffi_gc(ctx, C.X509_EXTENSION_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(x509_extension_ptr_ct, l.ctx) +end + +function _M.dup(ctx) + if not ffi.istype(x509_extension_ptr_ct, ctx) then + return nil, "x509.extension.dup: expect a x509.extension ctx at #1" + end + local ctx = C.X509_EXTENSION_dup(ctx) + if ctx == nil then + return nil, "x509.extension.dup: X509_EXTENSION_dup() failed" + end + + ffi_gc(ctx, C.X509_EXTENSION_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.from_der(value, txtnid, crit) + local nid, err = objects_lib.txtnid2nid(txtnid) + if err then + return nil, "x509.extension.from_der: " .. err + end + if type(value) ~= 'string' then + return nil, "x509.extension.from_der: expect string at #1" + end + + local asn1 = C.ASN1_STRING_new() + if asn1 == nil then + return nil, format_error("x509.extension.from_der: ASN1_STRING_new") + end + ffi_gc(asn1, C.ASN1_STRING_free) + + if C.ASN1_STRING_set(asn1, value, #value) ~= 1 then + return nil, format_error("x509.extension.from_der: ASN1_STRING_set") + end + + local ctx = C.X509_EXTENSION_create_by_NID(nil, nid, crit and 1 or 0, asn1) + if ctx == nil then + return nil, format_error("x509.extension.from_der: X509_EXTENSION_create_by_NID") + end + ffi_gc(ctx, C.X509_EXTENSION_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M:to_der() + local asn1 = C.X509_EXTENSION_get_data(self.ctx) + + return ffi_str(asn1_macro.ASN1_STRING_get0_data(asn1)) +end + +function _M.from_data(any, txtnid, crit) + local nid, err = objects_lib.txtnid2nid(txtnid) + if err then + return nil, "x509.extension.from_der: " .. err + end + + if type(any) ~= "table" or type(any.ctx) ~= "cdata" then + return nil, "x509.extension.from_data: expect a table with ctx at #1" + elseif type(nid) ~= "number" then + return nil, "x509.extension.from_data: expect a table at #2" + end + + local ctx = C.X509V3_EXT_i2d(nid, crit and 1 or 0, any.ctx) + if ctx == nil then + return nil, format_error("x509.extension.from_data: X509V3_EXT_i2d") + end + ffi_gc(ctx, C.X509_EXTENSION_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +local NID_subject_alt_name = C.OBJ_sn2nid("subjectAltName") +assert(NID_subject_alt_name ~= 0) + +function _M.to_data(extension, nid) + if not _M.istype(extension) then + return nil, "x509.extension.dup: expect a x509.extension ctx at #1" + elseif type(nid) ~= "number" then + return nil, "x509.extension.to_data: expect a table at #2" + end + + local void_ptr = C.X509V3_EXT_d2i(extension.ctx) + if void_ptr == nil then + return nil, format_error("x509.extension:to_data: X509V3_EXT_d2i") + end + + if nid == NID_subject_alt_name then + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a GENERAL_NAME + -- we left the elements referenced by the new-dup'ed stack + ffi_gc(void_ptr, stack_lib.gc_of("GENERAL_NAME")) + local got = ffi_cast("GENERAL_NAMES*", void_ptr) + local lib = require("resty.openssl.x509.altname") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) + end + + return nil, string.format("x509.extension:to_data: don't know how to convert to NID %d", nid) +end + +function _M:get_object() + -- retruns the internal pointer + local asn1 = C.X509_EXTENSION_get_object(self.ctx) + + return objects_lib.obj2table(asn1) +end + +function _M:get_critical() + return C.X509_EXTENSION_get_critical(self.ctx) == 1 +end + +function _M:set_critical(crit) + if C.X509_EXTENSION_set_critical(self.ctx, crit and 1 or 0) ~= 1 then + return false, format_error("x509.extension:set_critical") + end + return true +end + +function _M:tostring() + local ret, err = bio_util.read_wrap(C.X509V3_EXT_print, self.ctx, 0, 0) + if not err then + return ret + end + -- fallback to ASN.1 print + local asn1 = C.X509_EXTENSION_get_data(self.ctx) + return bio_util.read_wrap(C.ASN1_STRING_print, asn1) +end + +_M.text = _M.tostring + +mt.__tostring = function(tbl) + local txt, err = _M.text(tbl) + if err then + error(err) + end + return txt +end + + +return _M diff --git a/server/resty/openssl/x509/extension/dist_points.lua b/server/resty/openssl/x509/extension/dist_points.lua new file mode 100644 index 0000000..b1d419b --- /dev/null +++ b/server/resty/openssl/x509/extension/dist_points.lua @@ -0,0 +1,75 @@ +local ffi = require "ffi" + +require "resty.openssl.include.x509" +require "resty.openssl.include.x509v3" +local altname_lib = require "resty.openssl.x509.altname" +local stack_lib = require "resty.openssl.stack" + +local _M = {} + +local stack_ptr_ct = ffi.typeof("OPENSSL_STACK*") + +local STACK = "DIST_POINT" +local new = stack_lib.new_of(STACK) +local dup = stack_lib.dup_of(STACK) + +-- TODO: return other attributes? +local cdp_decode_fullname = function(ctx) + return altname_lib.dup(ctx.distpoint.name.fullname) +end + +local mt = stack_lib.mt_of(STACK, cdp_decode_fullname, _M) + +function _M.new() + local ctx = new() + if ctx == nil then + return nil, "OPENSSL_sk_new_null() failed" + end + + local self = setmetatable({ + ctx = ctx, + _is_shallow_copy = false, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.cast and ffi.istype(stack_ptr_ct, l.cast) +end + +function _M.dup(ctx) + if ctx == nil or not ffi.istype(stack_ptr_ct, ctx) then + return nil, "expect a stack ctx at #1" + end + local dup_ctx = dup(ctx) + + return setmetatable({ + ctx = dup_ctx, + -- don't let lua gc the original stack to keep its elements + _dupped_from = ctx, + _is_shallow_copy = true, + _elem_refs = {}, + _elem_refs_idx = 1, + }, mt), nil +end + +_M.all = function(stack) + local ret = {} + local _next = mt.__ipairs(stack) + while true do + local i, e = _next() + if i then + ret[i] = e + else + break + end + end + return ret +end + +_M.each = mt.__ipairs +_M.index = mt.__index +_M.count = mt.__len + +return _M diff --git a/server/resty/openssl/x509/extension/info_access.lua b/server/resty/openssl/x509/extension/info_access.lua new file mode 100644 index 0000000..21025a8 --- /dev/null +++ b/server/resty/openssl/x509/extension/info_access.lua @@ -0,0 +1,137 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_cast = ffi.cast + +require "resty.openssl.include.x509" +require "resty.openssl.include.x509v3" +require "resty.openssl.include.err" +local altname_lib = require "resty.openssl.x509.altname" +local stack_lib = require "resty.openssl.stack" + +local _M = {} + +local authority_info_access_ptr_ct = ffi.typeof("AUTHORITY_INFO_ACCESS*") + +local STACK = "ACCESS_DESCRIPTION" +local new = stack_lib.new_of(STACK) +local add = stack_lib.add_of(STACK) +local dup = stack_lib.dup_of(STACK) + +local aia_decode = function(ctx) + local nid = C.OBJ_obj2nid(ctx.method) + local gn = altname_lib.gn_decode(ctx.location) + return { nid, unpack(gn) } +end + +local mt = stack_lib.mt_of(STACK, aia_decode, _M) +local mt__pairs = mt.__pairs +mt.__pairs = function(tbl) + local f = mt__pairs(tbl) + return function() + local _, e = f() + if not e then return end + return unpack(e) + end +end + +function _M.new() + local ctx = new() + if ctx == nil then + return nil, "OPENSSL_sk_new_null() failed" + end + local cast = ffi_cast("AUTHORITY_INFO_ACCESS*", ctx) + + local self = setmetatable({ + ctx = ctx, + cast = cast, + _is_shallow_copy = false, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.cast and ffi.istype(authority_info_access_ptr_ct, l.cast) +end + +function _M.dup(ctx) + if ctx == nil or not ffi.istype(authority_info_access_ptr_ct, ctx) then + return nil, "expect a AUTHORITY_INFO_ACCESS* ctx at #1" + end + local dup_ctx = dup(ctx) + + return setmetatable({ + ctx = dup_ctx, + cast = ffi_cast("AUTHORITY_INFO_ACCESS*", dup_ctx), + -- don't let lua gc the original stack to keep its elements + _dupped_from = ctx, + _is_shallow_copy = true, + _elem_refs = {}, + _elem_refs_idx = 1, + }, mt), nil +end + +function _M:add(nid, typ, value) + -- the stack element stays with stack + -- we shouldn't add gc handler if it's already been + -- pushed to stack. instead, rely on the gc handler + -- of the stack to release all memories + local ad = C.ACCESS_DESCRIPTION_new() + if ad == nil then + return nil, "ACCESS_DESCRIPTION_new() failed" + end + + -- C.ASN1_OBJECT_free(ad.method) + + local asn1 = C.OBJ_txt2obj(nid, 0) + if asn1 == nil then + C.ACCESS_DESCRIPTION_free(ad) + -- clean up error occurs during OBJ_txt2* + C.ERR_clear_error() + return nil, "invalid NID text " .. (nid or "nil") + end + + ad.method = asn1 + + local err = altname_lib.gn_set(ad.location, typ, value) + if err then + C.ACCESS_DESCRIPTION_free(ad) + return nil, err + end + + local _, err = add(self.ctx, ad) + if err then + C.ACCESS_DESCRIPTION_free(ad) + return nil, err + end + + -- if the stack is duplicated, the gc handler is not pop_free + -- handle the gc by ourselves + if self._is_shallow_copy then + ffi_gc(ad, C.ACCESS_DESCRIPTION_free) + self._elem_refs[self._elem_refs_idx] = ad + self._elem_refs_idx = self._elem_refs_idx + 1 + end + return self +end + +_M.all = function(stack) + local ret = {} + local _next = mt.__ipairs(stack) + while true do + local i, e = _next() + if i then + ret[i] = e + else + break + end + end + return ret +end + +_M.each = mt.__ipairs +_M.index = mt.__index +_M.count = mt.__len + +return _M diff --git a/server/resty/openssl/x509/extensions.lua b/server/resty/openssl/x509/extensions.lua new file mode 100644 index 0000000..3b64b8a --- /dev/null +++ b/server/resty/openssl/x509/extensions.lua @@ -0,0 +1,84 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +local stack_lib = require "resty.openssl.stack" +local extension_lib = require "resty.openssl.x509.extension" +local format_error = require("resty.openssl.err").format_error + +local _M = {} + +local stack_ptr_ct = ffi.typeof("OPENSSL_STACK*") + +local STACK = "X509_EXTENSION" +local new = stack_lib.new_of(STACK) +local add = stack_lib.add_of(STACK) +local dup = stack_lib.dup_of(STACK) +local mt = stack_lib.mt_of(STACK, extension_lib.dup, _M) + +function _M.new() + local raw = new() + + local self = setmetatable({ + stack_of = STACK, + ctx = raw, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(stack_ptr_ct, l.ctx) + and l.stack_of and l.stack_of == STACK +end + +function _M.dup(ctx) + if ctx == nil or not ffi.istype(stack_ptr_ct, ctx) then + return nil, "x509.extensions.dup: expect a stack ctx at #1, got " .. type(ctx) + end + + local dup_ctx = dup(ctx) + + return setmetatable({ + ctx = dup_ctx, + -- don't let lua gc the original stack to keep its elements + _dupped_from = ctx, + _is_shallow_copy = true, + _elem_refs = {}, + _elem_refs_idx = 1, + }, mt), nil +end + +function _M:add(extension) + if not extension_lib.istype(extension) then + return nil, "expect a x509.extension instance at #1" + end + + local dup = C.X509_EXTENSION_dup(extension.ctx) + if dup == nil then + return nil, format_error("extensions:add: X509_EXTENSION_dup") + end + + local _, err = add(self.ctx, dup) + if err then + C.X509_EXTENSION_free(dup) + return nil, err + end + + -- if the stack is duplicated, the gc handler is not pop_free + -- handle the gc by ourselves + if self._is_shallow_copy then + ffi_gc(dup, C.X509_EXTENSION_free) + self._elem_refs[self._elem_refs_idx] = dup + self._elem_refs_idx = self._elem_refs_idx + 1 + end + + return true +end + +_M.all = stack_lib.all_func(mt) +_M.each = mt.__ipairs +_M.index = mt.__index +_M.count = mt.__len + +return _M diff --git a/server/resty/openssl/x509/init.lua b/server/resty/openssl/x509/init.lua new file mode 100644 index 0000000..5c259c8 --- /dev/null +++ b/server/resty/openssl/x509/init.lua @@ -0,0 +1,1071 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string +local ffi_cast = ffi.cast + +require "resty.openssl.include.x509" +require "resty.openssl.include.x509v3" +require "resty.openssl.include.evp" +require "resty.openssl.include.objects" +local stack_macro = require("resty.openssl.include.stack") +local stack_lib = require("resty.openssl.stack") +local asn1_lib = require("resty.openssl.asn1") +local digest_lib = require("resty.openssl.digest") +local extension_lib = require("resty.openssl.x509.extension") +local pkey_lib = require("resty.openssl.pkey") +local bio_util = require "resty.openssl.auxiliary.bio" +local txtnid2nid = require("resty.openssl.objects").txtnid2nid +local find_sigid_algs = require("resty.openssl.objects").find_sigid_algs +local ctypes = require "resty.openssl.auxiliary.ctypes" +local ctx_lib = require "resty.openssl.ctx" +local format_error = require("resty.openssl.err").format_error +local version = require("resty.openssl.version") +local OPENSSL_10 = version.OPENSSL_10 +local OPENSSL_11_OR_LATER = version.OPENSSL_11_OR_LATER +local OPENSSL_3X = version.OPENSSL_3X +local BORINGSSL = version.BORINGSSL +local BORINGSSL_110 = version.BORINGSSL_110 -- used in boringssl-fips-20190808 + +-- accessors provides an openssl version neutral interface to lua layer +-- it doesn't handle any error, expect that to be implemented in +-- _M.set_X or _M.get_X +local accessors = {} + +accessors.get_pubkey = C.X509_get_pubkey -- returns new evp_pkey instance, don't need to dup +accessors.set_pubkey = C.X509_set_pubkey +accessors.set_version = C.X509_set_version +accessors.set_serial_number = C.X509_set_serialNumber +accessors.get_subject_name = C.X509_get_subject_name -- returns internal ptr, we dup it +accessors.set_subject_name = C.X509_set_subject_name +accessors.get_issuer_name = C.X509_get_issuer_name -- returns internal ptr, we dup it +accessors.set_issuer_name = C.X509_set_issuer_name +accessors.get_signature_nid = C.X509_get_signature_nid + +-- generally, use get1 if we return a lua table wrapped ctx which doesn't support dup. +-- in that case, a new struct is returned from C api, and we will handle gc. +-- openssl will increment the reference count for returned ptr, and won't free it when +-- parent struct is freed. +-- otherwise, use get0, which returns an internal pointer, we don't need to free it up. +-- it will be gone together with the parent struct. + +if BORINGSSL_110 then + accessors.get_not_before = C.X509_get0_notBefore -- returns internal ptr, we convert to number + accessors.set_not_before = C.X509_set_notBefore + accessors.get_not_after = C.X509_get0_notAfter -- returns internal ptr, we convert to number + accessors.set_not_after = C.X509_set_notAfter + accessors.get_version = function(x509) + if x509 == nil or x509.cert_info == nil or x509.cert_info.validity == nil then + return nil + end + return C.ASN1_INTEGER_get(x509.cert_info.version) + end + accessors.get_serial_number = C.X509_get_serialNumber -- returns internal ptr, we convert to bn +elseif OPENSSL_11_OR_LATER then + accessors.get_not_before = C.X509_get0_notBefore -- returns internal ptr, we convert to number + accessors.set_not_before = C.X509_set1_notBefore + accessors.get_not_after = C.X509_get0_notAfter -- returns internal ptr, we convert to number + accessors.set_not_after = C.X509_set1_notAfter + accessors.get_version = C.X509_get_version -- returns int + accessors.get_serial_number = C.X509_get0_serialNumber -- returns internal ptr, we convert to bn +elseif OPENSSL_10 then + accessors.get_not_before = function(x509) + if x509 == nil or x509.cert_info == nil or x509.cert_info.validity == nil then + return nil + end + return x509.cert_info.validity.notBefore + end + accessors.set_not_before = C.X509_set_notBefore + accessors.get_not_after = function(x509) + if x509 == nil or x509.cert_info == nil or x509.cert_info.validity == nil then + return nil + end + return x509.cert_info.validity.notAfter + end + accessors.set_not_after = C.X509_set_notAfter + accessors.get_version = function(x509) + if x509 == nil or x509.cert_info == nil or x509.cert_info.validity == nil then + return nil + end + return C.ASN1_INTEGER_get(x509.cert_info.version) + end + accessors.get_serial_number = C.X509_get_serialNumber -- returns internal ptr, we convert to bn +end + +local function __tostring(self, fmt) + if not fmt or fmt == 'PEM' then + return bio_util.read_wrap(C.PEM_write_bio_X509, self.ctx) + elseif fmt == 'DER' then + return bio_util.read_wrap(C.i2d_X509_bio, self.ctx) + else + return nil, "x509:tostring: can only write PEM or DER format, not " .. fmt + end +end + +local _M = {} +local mt = { __index = _M, __tostring = __tostring } + + +local x509_ptr_ct = ffi.typeof("X509*") + +-- only PEM format is supported for now +function _M.new(cert, fmt, properties) + local ctx + if not cert then + -- routine for create a new cert + if OPENSSL_3X then + ctx = C.X509_new_ex(ctx_lib.get_libctx(), properties) + else + ctx = C.X509_new() + end + if ctx == nil then + return nil, format_error("x509.new") + end + ffi_gc(ctx, C.X509_free) + + C.X509_gmtime_adj(accessors.get_not_before(ctx), 0) + C.X509_gmtime_adj(accessors.get_not_after(ctx), 0) + elseif type(cert) == "string" then + -- routine for load an existing cert + local bio = C.BIO_new_mem_buf(cert, #cert) + if bio == nil then + return nil, format_error("x509.new: BIO_new_mem_buf") + end + + fmt = fmt or "*" + while true do -- luacheck: ignore 512 -- loop is executed at most once + if fmt == "PEM" or fmt == "*" then + ctx = C.PEM_read_bio_X509(bio, nil, nil, nil) + if ctx ~= nil then + break + elseif fmt == "*" then + -- BIO_reset; #define BIO_CTRL_RESET 1 + local code = C.BIO_ctrl(bio, 1, 0, nil) + if code ~= 1 then + C.BIO_free(bio) + return nil, "x509.new: BIO_ctrl() failed: " .. code + end + end + end + if fmt == "DER" or fmt == "*" then + ctx = C.d2i_X509_bio(bio, nil) + end + break + end + C.BIO_free(bio) + if ctx == nil then + return nil, format_error("x509.new") + end + -- clear errors occur when trying + C.ERR_clear_error() + ffi_gc(ctx, C.X509_free) + elseif type(cert) == 'cdata' then + if ffi.istype(x509_ptr_ct, cert) then + ctx = cert + ffi_gc(ctx, C.X509_free) + else + return nil, "x509.new: expect a X509* cdata at #1" + end + else + return nil, "x509.new: expect nil or a string at #1" + end + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(x509_ptr_ct, l.ctx) +end + +function _M.dup(ctx) + if not ffi.istype(x509_ptr_ct, ctx) then + return nil, "x509.dup: expect a x509 ctx at #1" + end + local ctx = C.X509_dup(ctx) + if ctx == nil then + return nil, "x509.dup: X509_dup() failed" + end + + ffi_gc(ctx, C.X509_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M:tostring(fmt) + return __tostring(self, fmt) +end + +function _M:to_PEM() + return __tostring(self, "PEM") +end + +function _M:set_lifetime(not_before, not_after) + local ok, err + if not_before then + ok, err = self:set_not_before(not_before) + if err then + return ok, err + end + end + + if not_after then + ok, err = self:set_not_after(not_after) + if err then + return ok, err + end + end + + return true +end + +function _M:get_lifetime() + local not_before, err = self:get_not_before() + if not_before == nil then + return nil, nil, err + end + local not_after, err = self:get_not_after() + if not_after == nil then + return nil, nil, err + end + + return not_before, not_after, nil +end + +-- note: index is 0 based +local OPENSSL_STRING_value_at = function(ctx, i) + local ct = ffi_cast("OPENSSL_STRING", stack_macro.OPENSSL_sk_value(ctx, i)) + if ct == nil then + return nil + end + return ffi_str(ct) +end + +function _M:get_ocsp_url(return_all) + local st = C.X509_get1_ocsp(self.ctx) + + local count = stack_macro.OPENSSL_sk_num(st) + if count == 0 then + return + end + + local ret + if return_all then + ret = {} + for i=0,count-1 do + ret[i+1] = OPENSSL_STRING_value_at(st, i) + end + else + ret = OPENSSL_STRING_value_at(st, 0) + end + + C.X509_email_free(st) + return ret +end + +function _M:get_ocsp_request() + +end + +function _M:get_crl_url(return_all) + local cdp, err = self:get_crl_distribution_points() + if err then + return nil, err + end + + if not cdp or cdp:count() == 0 then + return + end + + if return_all then + local ret = {} + local cdp_iter = cdp:each() + while true do + local _, gn = cdp_iter() + if not gn then + break + end + local gn_iter = gn:each() + while true do + local k, v = gn_iter() + if not k then + break + elseif k == "URI" then + table.insert(ret, v) + end + end + end + return ret + else + local gn, err = cdp:index(1) + if err then + return nil, err + end + local iter = gn:each() + while true do + local k, v = iter() + if not k then + break + elseif k == "URI" then + return v + end + end + end +end + +local digest_length = ctypes.ptr_of_uint() +local digest_buf, digest_buf_size +local function digest(self, cfunc, typ, properties) + -- TODO: dedup the following with resty.openssl.digest + local ctx + if OPENSSL_11_OR_LATER then + ctx = C.EVP_MD_CTX_new() + ffi_gc(ctx, C.EVP_MD_CTX_free) + elseif OPENSSL_10 then + ctx = C.EVP_MD_CTX_create() + ffi_gc(ctx, C.EVP_MD_CTX_destroy) + end + if ctx == nil then + return nil, "x509:digest: failed to create EVP_MD_CTX" + end + + local algo + if OPENSSL_3X then + algo = C.EVP_MD_fetch(ctx_lib.get_libctx(), typ or 'sha1', properties) + else + algo = C.EVP_get_digestbyname(typ or 'sha1') + end + if algo == nil then + return nil, string.format("x509:digest: invalid digest type \"%s\"", typ) + end + + local md_size = OPENSSL_3X and C.EVP_MD_get_size(algo) or C.EVP_MD_size(algo) + if not digest_buf or digest_buf_size < md_size then + digest_buf = ctypes.uchar_array(md_size) + digest_buf_size = md_size + end + + if cfunc(self.ctx, algo, digest_buf, digest_length) ~= 1 then + return nil, format_error("x509:digest") + end + + return ffi_str(digest_buf, digest_length[0]) +end + +function _M:digest(typ, properties) + return digest(self, C.X509_digest, typ, properties) +end + +function _M:pubkey_digest(typ, properties) + return digest(self, C.X509_pubkey_digest, typ, properties) +end + +function _M:check_private_key(key) + if not pkey_lib.istype(key) then + return false, "x509:check_private_key: except a pkey instance at #1" + end + + if not key:is_private() then + return false, "x509:check_private_key: not a private key" + end + + if C.X509_check_private_key(self.ctx, key.ctx) == 1 then + return true + end + return false, format_error("x509:check_private_key") +end + +-- START AUTO GENERATED CODE + +-- AUTO GENERATED +function _M:sign(pkey, digest) + if not pkey_lib.istype(pkey) then + return false, "x509:sign: expect a pkey instance at #1" + end + + local digest_algo + if digest then + if not digest_lib.istype(digest) then + return false, "x509:sign: expect a digest instance at #2" + elseif not digest.algo then + return false, "x509:sign: expect a digest instance to have algo member" + end + digest_algo = digest.algo + elseif BORINGSSL then + digest_algo = C.EVP_get_digestbyname('sha256') + end + + -- returns size of signature if success + if C.X509_sign(self.ctx, pkey.ctx, digest_algo) == 0 then + return false, format_error("x509:sign") + end + + return true +end + +-- AUTO GENERATED +function _M:verify(pkey) + if not pkey_lib.istype(pkey) then + return false, "x509:verify: expect a pkey instance at #1" + end + + local code = C.X509_verify(self.ctx, pkey.ctx) + if code == 1 then + return true + elseif code == 0 then + return false + else -- typically -1 + return false, format_error("x509:verify", code) + end +end + +-- AUTO GENERATED +local function get_extension(ctx, nid_txt, last_pos) + last_pos = (last_pos or 0) - 1 + local nid, err = txtnid2nid(nid_txt) + if err then + return nil, nil, err + end + local pos = C.X509_get_ext_by_NID(ctx, nid, last_pos) + if pos == -1 then + return nil + end + local ctx = C.X509_get_ext(ctx, pos) + if ctx == nil then + return nil, nil, format_error() + end + return ctx, pos +end + +-- AUTO GENERATED +function _M:add_extension(extension) + if not extension_lib.istype(extension) then + return false, "x509:add_extension: expect a x509.extension instance at #1" + end + + -- X509_add_ext returnes the stack on success, and NULL on error + -- the X509_EXTENSION ctx is dupped internally + if C.X509_add_ext(self.ctx, extension.ctx, -1) == nil then + return false, format_error("x509:add_extension") + end + + return true +end + +-- AUTO GENERATED +function _M:get_extension(nid_txt, last_pos) + local ctx, pos, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, nil, "x509:get_extension: " .. err + end + local ext, err = extension_lib.dup(ctx) + if err then + return nil, nil, "x509:get_extension: " .. err + end + return ext, pos+1 +end + +local X509_delete_ext +if OPENSSL_11_OR_LATER then + X509_delete_ext = C.X509_delete_ext +elseif OPENSSL_10 then + X509_delete_ext = function(ctx, pos) + return C.X509v3_delete_ext(ctx.cert_info.extensions, pos) + end +else + X509_delete_ext = function(...) + error("X509_delete_ext undefined") + end +end + +-- AUTO GENERATED +function _M:set_extension(extension, last_pos) + if not extension_lib.istype(extension) then + return false, "x509:set_extension: expect a x509.extension instance at #1" + end + + last_pos = (last_pos or 0) - 1 + + local nid = extension:get_object().nid + local pos = C.X509_get_ext_by_NID(self.ctx, nid, last_pos) + -- pos may be -1, which means not found, it's fine, we will add new one instead of replace + + local removed = X509_delete_ext(self.ctx, pos) + C.X509_EXTENSION_free(removed) + + if C.X509_add_ext(self.ctx, extension.ctx, pos) == nil then + return false, format_error("x509:set_extension") + end + + return true +end + +-- AUTO GENERATED +function _M:set_extension_critical(nid_txt, crit, last_pos) + local ctx, _, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, "x509:set_extension_critical: " .. err + end + + if C.X509_EXTENSION_set_critical(ctx, crit and 1 or 0) ~= 1 then + return false, format_error("x509:set_extension_critical") + end + + return true +end + +-- AUTO GENERATED +function _M:get_extension_critical(nid_txt, last_pos) + local ctx, _, err = get_extension(self.ctx, nid_txt, last_pos) + if err then + return nil, "x509:get_extension_critical: " .. err + end + + return C.X509_EXTENSION_get_critical(ctx) == 1 +end + +-- AUTO GENERATED +function _M:get_serial_number() + local got = accessors.get_serial_number(self.ctx) + if got == nil then + return nil + end + + -- returns a new BIGNUM instance + got = C.ASN1_INTEGER_to_BN(got, nil) + if got == nil then + return false, format_error("x509:set: BN_to_ASN1_INTEGER") + end + -- bn will be duplicated thus this ctx should be freed up + ffi_gc(got, C.BN_free) + + local lib = require("resty.openssl.bn") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED +function _M:set_serial_number(toset) + local lib = require("resty.openssl.bn") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_serial_number: expect a bn instance at #1" + end + toset = toset.ctx + + toset = C.BN_to_ASN1_INTEGER(toset, nil) + if toset == nil then + return false, format_error("x509:set: BN_to_ASN1_INTEGER") + end + -- "A copy of the serial number is used internally + -- so serial should be freed up after use."" + ffi_gc(toset, C.ASN1_INTEGER_free) + + if accessors.set_serial_number(self.ctx, toset) == 0 then + return false, format_error("x509:set_serial_number") + end + return true +end + +-- AUTO GENERATED +function _M:get_not_before() + local got = accessors.get_not_before(self.ctx) + if got == nil then + return nil + end + + got = asn1_lib.asn1_to_unix(got) + + return got +end + +-- AUTO GENERATED +function _M:set_not_before(toset) + if type(toset) ~= "number" then + return false, "x509:set_not_before: expect a number at #1" + end + + toset = C.ASN1_TIME_set(nil, toset) + ffi_gc(toset, C.ASN1_STRING_free) + + if accessors.set_not_before(self.ctx, toset) == 0 then + return false, format_error("x509:set_not_before") + end + return true +end + +-- AUTO GENERATED +function _M:get_not_after() + local got = accessors.get_not_after(self.ctx) + if got == nil then + return nil + end + + got = asn1_lib.asn1_to_unix(got) + + return got +end + +-- AUTO GENERATED +function _M:set_not_after(toset) + if type(toset) ~= "number" then + return false, "x509:set_not_after: expect a number at #1" + end + + toset = C.ASN1_TIME_set(nil, toset) + ffi_gc(toset, C.ASN1_STRING_free) + + if accessors.set_not_after(self.ctx, toset) == 0 then + return false, format_error("x509:set_not_after") + end + return true +end + +-- AUTO GENERATED +function _M:get_pubkey() + local got = accessors.get_pubkey(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.pkey") + -- returned a copied instance directly + return lib.new(got) +end + +-- AUTO GENERATED +function _M:set_pubkey(toset) + local lib = require("resty.openssl.pkey") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_pubkey: expect a pkey instance at #1" + end + toset = toset.ctx + if accessors.set_pubkey(self.ctx, toset) == 0 then + return false, format_error("x509:set_pubkey") + end + return true +end + +-- AUTO GENERATED +function _M:get_subject_name() + local got = accessors.get_subject_name(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.x509.name") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED +function _M:set_subject_name(toset) + local lib = require("resty.openssl.x509.name") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_subject_name: expect a x509.name instance at #1" + end + toset = toset.ctx + if accessors.set_subject_name(self.ctx, toset) == 0 then + return false, format_error("x509:set_subject_name") + end + return true +end + +-- AUTO GENERATED +function _M:get_issuer_name() + local got = accessors.get_issuer_name(self.ctx) + if got == nil then + return nil + end + local lib = require("resty.openssl.x509.name") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED +function _M:set_issuer_name(toset) + local lib = require("resty.openssl.x509.name") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_issuer_name: expect a x509.name instance at #1" + end + toset = toset.ctx + if accessors.set_issuer_name(self.ctx, toset) == 0 then + return false, format_error("x509:set_issuer_name") + end + return true +end + +-- AUTO GENERATED +function _M:get_version() + local got = accessors.get_version(self.ctx) + if got == nil then + return nil + end + + got = tonumber(got) + 1 + + return got +end + +-- AUTO GENERATED +function _M:set_version(toset) + if type(toset) ~= "number" then + return false, "x509:set_version: expect a number at #1" + end + + -- Note: this is defined by standards (X.509 et al) to be one less than the certificate version. + -- So a version 3 certificate will return 2 and a version 1 certificate will return 0. + toset = toset - 1 + + if accessors.set_version(self.ctx, toset) == 0 then + return false, format_error("x509:set_version") + end + return true +end + +local NID_subject_alt_name = C.OBJ_sn2nid("subjectAltName") +assert(NID_subject_alt_name ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_subject_alt_name() + local crit = ctypes.ptr_of_int() + -- X509_get_ext_d2i returns internal pointer, always dup + -- for now this function always returns the first found extension + local got = C.X509_get_ext_d2i(self.ctx, NID_subject_alt_name, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509:get_subject_alt_name: extension of subject_alt_name occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509:get_subject_alt_name") + end + + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a GENERAL_NAME + -- we left the elements referenced by the new-dup'ed stack + local got_ref = got + ffi_gc(got_ref, stack_lib.gc_of("GENERAL_NAME")) + got = ffi_cast("GENERAL_NAMES*", got_ref) + local lib = require("resty.openssl.x509.altname") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_subject_alt_name(toset) + local lib = require("resty.openssl.x509.altname") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_subject_alt_name: expect a x509.altname instance at #1" + end + toset = toset.ctx + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + if C.X509_add1_ext_i2d(self.ctx, NID_subject_alt_name, toset, 0, 0x2) ~= 1 then + return false, format_error("x509:set_subject_alt_name") + end + return true +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_subject_alt_name_critical(crit) + return _M.set_extension_critical(self, NID_subject_alt_name, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_subject_alt_name_critical() + return _M.get_extension_critical(self, NID_subject_alt_name) +end + +local NID_issuer_alt_name = C.OBJ_sn2nid("issuerAltName") +assert(NID_issuer_alt_name ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_issuer_alt_name() + local crit = ctypes.ptr_of_int() + -- X509_get_ext_d2i returns internal pointer, always dup + -- for now this function always returns the first found extension + local got = C.X509_get_ext_d2i(self.ctx, NID_issuer_alt_name, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509:get_issuer_alt_name: extension of issuer_alt_name occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509:get_issuer_alt_name") + end + + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a GENERAL_NAME + -- we left the elements referenced by the new-dup'ed stack + local got_ref = got + ffi_gc(got_ref, stack_lib.gc_of("GENERAL_NAME")) + got = ffi_cast("GENERAL_NAMES*", got_ref) + local lib = require("resty.openssl.x509.altname") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_issuer_alt_name(toset) + local lib = require("resty.openssl.x509.altname") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_issuer_alt_name: expect a x509.altname instance at #1" + end + toset = toset.ctx + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + if C.X509_add1_ext_i2d(self.ctx, NID_issuer_alt_name, toset, 0, 0x2) ~= 1 then + return false, format_error("x509:set_issuer_alt_name") + end + return true +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_issuer_alt_name_critical(crit) + return _M.set_extension_critical(self, NID_issuer_alt_name, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_issuer_alt_name_critical() + return _M.get_extension_critical(self, NID_issuer_alt_name) +end + +local NID_basic_constraints = C.OBJ_sn2nid("basicConstraints") +assert(NID_basic_constraints ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_basic_constraints(name) + local crit = ctypes.ptr_of_int() + -- X509_get_ext_d2i returns internal pointer, always dup + -- for now this function always returns the first found extension + local got = C.X509_get_ext_d2i(self.ctx, NID_basic_constraints, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509:get_basic_constraints: extension of basic_constraints occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509:get_basic_constraints") + end + + local ctx = ffi_cast("BASIC_CONSTRAINTS*", got) + + local ca = ctx.ca == 0xFF + local pathlen = tonumber(C.ASN1_INTEGER_get(ctx.pathlen)) + + C.BASIC_CONSTRAINTS_free(ctx) + + if not name or type(name) ~= "string" then + got = { + ca = ca, + pathlen = pathlen, + } + elseif string.lower(name) == "ca" then + got = ca + elseif string.lower(name) == "pathlen" then + got = pathlen + end + + return got +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_basic_constraints(toset) + if type(toset) ~= "table" then + return false, "x509:set_basic_constraints: expect a table at #1" + end + + local cfg_lower = {} + for k, v in pairs(toset) do + cfg_lower[string.lower(k)] = v + end + + toset = C.BASIC_CONSTRAINTS_new() + if toset == nil then + return false, format_error("x509:set_BASIC_CONSTRAINTS") + end + ffi_gc(toset, C.BASIC_CONSTRAINTS_free) + + toset.ca = cfg_lower.ca and 0xFF or 0 + local pathlen = cfg_lower.pathlen and tonumber(cfg_lower.pathlen) + if pathlen then + C.ASN1_INTEGER_free(toset.pathlen) + + local asn1 = C.ASN1_STRING_type_new(pathlen) + if asn1 == nil then + return false, format_error("x509:set_BASIC_CONSTRAINTS: ASN1_STRING_type_new") + end + toset.pathlen = asn1 + + local code = C.ASN1_INTEGER_set(asn1, pathlen) + if code ~= 1 then + return false, format_error("x509:set_BASIC_CONSTRAINTS: ASN1_INTEGER_set", code) + end + end + + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + if C.X509_add1_ext_i2d(self.ctx, NID_basic_constraints, toset, 0, 0x2) ~= 1 then + return false, format_error("x509:set_basic_constraints") + end + return true +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_basic_constraints_critical(crit) + return _M.set_extension_critical(self, NID_basic_constraints, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_basic_constraints_critical() + return _M.get_extension_critical(self, NID_basic_constraints) +end + +local NID_info_access = C.OBJ_sn2nid("authorityInfoAccess") +assert(NID_info_access ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_info_access() + local crit = ctypes.ptr_of_int() + -- X509_get_ext_d2i returns internal pointer, always dup + -- for now this function always returns the first found extension + local got = C.X509_get_ext_d2i(self.ctx, NID_info_access, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509:get_info_access: extension of info_access occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509:get_info_access") + end + + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a ACCESS_DESCRIPTION + -- we left the elements referenced by the new-dup'ed stack + local got_ref = got + ffi_gc(got_ref, stack_lib.gc_of("ACCESS_DESCRIPTION")) + got = ffi_cast("AUTHORITY_INFO_ACCESS*", got_ref) + local lib = require("resty.openssl.x509.extension.info_access") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_info_access(toset) + local lib = require("resty.openssl.x509.extension.info_access") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_info_access: expect a x509.extension.info_access instance at #1" + end + toset = toset.ctx + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + if C.X509_add1_ext_i2d(self.ctx, NID_info_access, toset, 0, 0x2) ~= 1 then + return false, format_error("x509:set_info_access") + end + return true +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_info_access_critical(crit) + return _M.set_extension_critical(self, NID_info_access, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_info_access_critical() + return _M.get_extension_critical(self, NID_info_access) +end + +local NID_crl_distribution_points = C.OBJ_sn2nid("crlDistributionPoints") +assert(NID_crl_distribution_points ~= 0) + +-- AUTO GENERATED: EXTENSIONS +function _M:get_crl_distribution_points() + local crit = ctypes.ptr_of_int() + -- X509_get_ext_d2i returns internal pointer, always dup + -- for now this function always returns the first found extension + local got = C.X509_get_ext_d2i(self.ctx, NID_crl_distribution_points, crit, nil) + crit = tonumber(crit[0]) + if crit == -1 then -- not found + return nil + elseif crit == -2 then + return nil, "x509:get_crl_distribution_points: extension of crl_distribution_points occurs more than one times, " .. + "this is not yet implemented. Please use get_extension instead." + elseif got == nil then + return nil, format_error("x509:get_crl_distribution_points") + end + + -- Note: here we only free the stack itself not elements + -- since there seems no way to increase ref count for a DIST_POINT + -- we left the elements referenced by the new-dup'ed stack + local got_ref = got + ffi_gc(got_ref, stack_lib.gc_of("DIST_POINT")) + got = ffi_cast("OPENSSL_STACK*", got_ref) + local lib = require("resty.openssl.x509.extension.dist_points") + -- the internal ptr is returned, ie we need to copy it + return lib.dup(got) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_crl_distribution_points(toset) + local lib = require("resty.openssl.x509.extension.dist_points") + if lib.istype and not lib.istype(toset) then + return false, "x509:set_crl_distribution_points: expect a x509.extension.dist_points instance at #1" + end + toset = toset.ctx + -- x509v3.h: # define X509V3_ADD_REPLACE 2L + if C.X509_add1_ext_i2d(self.ctx, NID_crl_distribution_points, toset, 0, 0x2) ~= 1 then + return false, format_error("x509:set_crl_distribution_points") + end + return true +end + +-- AUTO GENERATED: EXTENSIONS +function _M:set_crl_distribution_points_critical(crit) + return _M.set_extension_critical(self, NID_crl_distribution_points, crit) +end + +-- AUTO GENERATED: EXTENSIONS +function _M:get_crl_distribution_points_critical() + return _M.get_extension_critical(self, NID_crl_distribution_points) +end + + +-- AUTO GENERATED +function _M:get_signature_nid() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509:get_signature_nid") + end + + return nid +end + +-- AUTO GENERATED +function _M:get_signature_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509:get_signature_name") + end + + return ffi.string(C.OBJ_nid2sn(nid)) +end + +-- AUTO GENERATED +function _M:get_signature_digest_name() + local nid = accessors.get_signature_nid(self.ctx) + if nid <= 0 then + return nil, format_error("x509:get_signature_digest_name") + end + + local nid = find_sigid_algs(nid) + + return ffi.string(C.OBJ_nid2sn(nid)) +end +-- END AUTO GENERATED CODE + +return _M diff --git a/server/resty/openssl/x509/name.lua b/server/resty/openssl/x509/name.lua new file mode 100644 index 0000000..f83fcc1 --- /dev/null +++ b/server/resty/openssl/x509/name.lua @@ -0,0 +1,156 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string + +require "resty.openssl.include.x509.name" +require "resty.openssl.include.err" +local objects_lib = require "resty.openssl.objects" +local asn1_macro = require "resty.openssl.include.asn1" + +-- local MBSTRING_FLAG = 0x1000 +local MBSTRING_ASC = 0x1001 -- (MBSTRING_FLAG|1) + +local _M = {} + +local x509_name_ptr_ct = ffi.typeof("X509_NAME*") + +-- starts from 0 +local function value_at(ctx, i) + local entry = C.X509_NAME_get_entry(ctx, i) + local obj = C.X509_NAME_ENTRY_get_object(entry) + local ret = objects_lib.obj2table(obj) + + local str = C.X509_NAME_ENTRY_get_data(entry) + if str ~= nil then + ret.blob = ffi_str(asn1_macro.ASN1_STRING_get0_data(str)) + end + + return ret +end + +local function iter(tbl) + local i = 0 + local n = tonumber(C.X509_NAME_entry_count(tbl.ctx)) + return function() + i = i + 1 + if i <= n then + local obj = value_at(tbl.ctx, i-1) + return obj.sn or obj.ln or obj.id, obj + end + end +end + +local mt = { + __index = _M, + __pairs = iter, + __len = function(tbl) return tonumber(C.X509_NAME_entry_count(tbl.ctx)) end, +} + +function _M.new() + local ctx = C.X509_NAME_new() + if ctx == nil then + return nil, "x509.name.new: X509_NAME_new() failed" + end + ffi_gc(ctx, C.X509_NAME_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(x509_name_ptr_ct, l.ctx) +end + +function _M.dup(ctx) + if not ffi.istype(x509_name_ptr_ct, ctx) then + return nil, "x509.name.dup: expect a x509.name ctx at #1, got " .. type(ctx) + end + local ctx = C.X509_NAME_dup(ctx) + ffi_gc(ctx, C.X509_NAME_free) + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +function _M:add(nid, txt) + local asn1 = C.OBJ_txt2obj(nid, 0) + if asn1 == nil then + -- clean up error occurs during OBJ_txt2* + C.ERR_clear_error() + return nil, "x509.name:add: invalid NID text " .. (nid or "nil") + end + + local code = C.X509_NAME_add_entry_by_OBJ(self.ctx, asn1, MBSTRING_ASC, txt, #txt, -1, 0) + C.ASN1_OBJECT_free(asn1) + + if code ~= 1 then + return nil, "x509.name:add: X509_NAME_add_entry_by_OBJ() failed" + end + + return self +end + +function _M:find(nid, last_pos) + local asn1 = C.OBJ_txt2obj(nid, 0) + if asn1 == nil then + -- clean up error occurs during OBJ_txt2* + C.ERR_clear_error() + return nil, nil, "x509.name:find: invalid NID text " .. (nid or "nil") + end + -- make 1-index array to 0-index + last_pos = (last_pos or 0) - 1 + + local pos = C.X509_NAME_get_index_by_OBJ(self.ctx, asn1, last_pos) + if pos == -1 then + return nil + end + + C.ASN1_OBJECT_free(asn1) + + return value_at(self.ctx, pos), pos+1 +end + +-- fallback function to iterate if LUAJIT_ENABLE_LUA52COMPAT not enabled +function _M:all() + local ret = {} + local _next = iter(self) + while true do + local k, obj = _next() + if obj then + ret[k] = obj + else + break + end + end + return ret +end + +function _M:each() + return iter(self) +end + +mt.__tostring = function(self) + local values = {} + local _next = iter(self) + while true do + local k, v = _next() + if k then + table.insert(values, k .. "=" .. v.blob) + else + break + end + end + table.sort(values) + return table.concat(values, "/") +end + +_M.tostring = mt.__tostring + +return _M diff --git a/server/resty/openssl/x509/revoked.lua b/server/resty/openssl/x509/revoked.lua new file mode 100644 index 0000000..9762200 --- /dev/null +++ b/server/resty/openssl/x509/revoked.lua @@ -0,0 +1,108 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc + +require "resty.openssl.include.x509.crl" +require "resty.openssl.include.x509.revoked" +local bn_lib = require("resty.openssl.bn") +local format_error = require("resty.openssl.err").format_error + +local _M = {} +local mt = { __index = _M } + +local x509_revoked_ptr_ct = ffi.typeof('X509_REVOKED*') + +local NID_crl_reason = C.OBJ_txt2nid("CRLReason") +assert(NID_crl_reason > 0) + +--- Creates new instance of X509_REVOKED data +-- @tparam bn|number sn Serial number as number or bn instance +-- @tparam number time Revocation time +-- @tparam number reason Revocation reason +-- @treturn table instance of the module or nil +-- @treturn[opt] string Returns optional error message in case of error +function _M.new(sn, time, reason) + --- only convert to bn if it is number + if type(sn) == "number"then + sn = bn_lib.new(sn) + end + if not bn_lib.istype(sn) then + return nil, "x509.revoked.new: sn should be number or a bn instance" + end + + if type(time) ~= "number" then + return nil, "x509.revoked.new: expect a number at #2" + end + if type(reason) ~= "number" then + return nil, "x509.revoked.new: expect a number at #3" + end + + local ctx = C.X509_REVOKED_new() + ffi_gc(ctx, C.X509_REVOKED_free) + + -- serial number + local sn_asn1 = C.BN_to_ASN1_INTEGER(sn.ctx, nil) + if sn_asn1 == nil then + return nil, "x509.revoked.new: BN_to_ASN1_INTEGER() failed" + end + ffi_gc(sn_asn1, C.ASN1_INTEGER_free) + + if C.X509_REVOKED_set_serialNumber(ctx, sn_asn1) == 0 then + return nil, format_error("x509.revoked.new: X509_REVOKED_set_serialNumber()") + end + + -- time + time = C.ASN1_TIME_set(nil, time) + if time == nil then + return nil, format_error("x509.revoked.new: ASN1_TIME_set()") + end + ffi_gc(time, C.ASN1_STRING_free) + + if C.X509_REVOKED_set_revocationDate(ctx, time) == 0 then + return nil, format_error("x509.revoked.new: X509_REVOKED_set_revocationDate()") + end + + -- reason + local reason_asn1 = C.ASN1_ENUMERATED_new() + if reason_asn1 == nil then + return nil, "x509.revoked.new: ASN1_ENUMERATED_new() failed" + end + ffi_gc(reason_asn1, C.ASN1_ENUMERATED_free) + + local reason_ext = C.X509_EXTENSION_new() + if reason_ext == nil then + return nil, "x509.revoked.new: X509_EXTENSION_new() failed" + end + ffi_gc(reason_ext, C.X509_EXTENSION_free) + + if C.ASN1_ENUMERATED_set(reason_asn1, reason) == 0 then + return nil, format_error("x509.revoked.new: ASN1_ENUMERATED_set()") + end + + if C.X509_EXTENSION_set_data(reason_ext, reason_asn1) == 0 then + return nil, format_error("x509.revoked.new: X509_EXTENSION_set_data()") + end + + if C.X509_EXTENSION_set_object(reason_ext, C.OBJ_nid2obj(NID_crl_reason)) == 0 then + return nil, format_error("x509.revoked.new: X509_EXTENSION_set_object()") + end + + if C.X509_REVOKED_add_ext(ctx, reason_ext, 0) == 0 then + return nil, format_error("x509.revoked.new: X509_EXTENSION_set_object()") + end + + local self = setmetatable({ + ctx = ctx, + }, mt) + + return self, nil +end + +--- Type check +-- @tparam table Instance of revoked module +-- @treturn boolean true if instance is instance of revoked module false otherwise +function _M.istype(l) + return l and l.ctx and ffi.istype(x509_revoked_ptr_ct, l.ctx) +end + +return _M diff --git a/server/resty/openssl/x509/store.lua b/server/resty/openssl/x509/store.lua new file mode 100644 index 0000000..1722e4c --- /dev/null +++ b/server/resty/openssl/x509/store.lua @@ -0,0 +1,227 @@ +local ffi = require "ffi" +local C = ffi.C +local ffi_gc = ffi.gc +local ffi_str = ffi.string +local bor = bit.bor + +local x509_vfy_macro = require "resty.openssl.include.x509_vfy" +local x509_lib = require "resty.openssl.x509" +local chain_lib = require "resty.openssl.x509.chain" +local crl_lib = require "resty.openssl.x509.crl" +local ctx_lib = require "resty.openssl.ctx" +local format_error = require("resty.openssl.err").format_all_error +local format_all_error = require("resty.openssl.err").format_error +local OPENSSL_3X = require("resty.openssl.version").OPENSSL_3X + +local _M = {} +local mt = { __index = _M } + +_M.verify_flags = x509_vfy_macro.verify_flags + +local x509_store_ptr_ct = ffi.typeof('X509_STORE*') + +function _M.new() + local ctx = C.X509_STORE_new() + if ctx == nil then + return nil, "x509.store.new: X509_STORE_new() failed" + end + ffi_gc(ctx, C.X509_STORE_free) + + local self = setmetatable({ + ctx = ctx, + _elem_refs = {}, + _elem_refs_idx = 1, + }, mt) + + return self, nil +end + +function _M.istype(l) + return l and l.ctx and ffi.istype(x509_store_ptr_ct, l.ctx) +end + +function _M:use_default(properties) + if x509_vfy_macro.X509_STORE_set_default_paths(self.ctx, ctx_lib.get_libctx(), properties) ~= 1 then + return false, format_all_error("x509.store:use_default") + end + return true +end + +function _M:add(item) + local dup + local err + if x509_lib.istype(item) then + dup = C.X509_dup(item.ctx) + if dup == nil then + return false, "x509.store:add: X509_dup() failed" + end + -- ref counter of dup is increased by 1 + if C.X509_STORE_add_cert(self.ctx, dup) ~= 1 then + err = format_all_error("x509.store:add: X509_STORE_add_cert") + end + -- decrease the dup ctx ref count immediately to make leak test happy + C.X509_free(dup) + elseif crl_lib.istype(item) then + dup = C.X509_CRL_dup(item.ctx) + if dup == nil then + return false, "x509.store:add: X509_CRL_dup() failed" + end + -- ref counter of dup is increased by 1 + if C.X509_STORE_add_crl(self.ctx, dup) ~= 1 then + err = format_all_error("x509.store:add: X509_STORE_add_crl") + end + + -- define X509_V_FLAG_CRL_CHECK 0x4 + -- enables CRL checking for the certificate chain leaf certificate. + -- An error occurs if a suitable CRL cannot be found. + -- Note: this does not check for certificates in the chain. + if C.X509_STORE_set_flags(self.ctx, 0x4) ~= 1 then + return false, format_error("x509.store:add: X509_STORE_set_flags") + end + -- decrease the dup ctx ref count immediately to make leak test happy + C.X509_CRL_free(dup) + else + return false, "x509.store:add: expect an x509 or crl instance at #1" + end + + if err then + return false, err + end + + -- X509_STORE doesn't have stack gc handler, we need to gc by ourselves + self._elem_refs[self._elem_refs_idx] = dup + self._elem_refs_idx = self._elem_refs_idx + 1 + + return true +end + +function _M:load_file(path, properties) + if type(path) ~= "string" then + return false, "x509.store:load_file: expect a string at #1" + else + if x509_vfy_macro.X509_STORE_load_locations(self.ctx, path, nil, + ctx_lib.get_libctx(), properties) ~= 1 then + return false, format_all_error("x509.store:load_file") + end + end + + return true +end + +function _M:load_directory(path, properties) + if type(path) ~= "string" then + return false, "x509.store:load_directory expect a string at #1" + else + if x509_vfy_macro.X509_STORE_load_locations(self.ctx, nil, path, + ctx_lib.get_libctx(), properties) ~= 1 then + return false, format_all_error("x509.store:load_directory") + end + end + + return true +end + +function _M:set_depth(depth) + depth = depth and tonumber(depth) + if not depth then + return nil, "x509.store:set_depth: expect a number at #1" + end + + if C.X509_STORE_set_depth(self.ctx, depth) ~= 1 then + return false, format_error("x509.store:set_depth") + end + + return true +end + +function _M:set_purpose(purpose) + if type(purpose) ~= "string" then + return nil, "x509.store:set_purpose: expect a string at #1" + end + + local pchar = ffi.new("char[?]", #purpose, purpose) + local idx = C.X509_PURPOSE_get_by_sname(pchar) + idx = tonumber(idx) + + if idx == -1 then + return false, "invalid purpose \"" .. purpose .. "\"" + end + + local purp = C.X509_PURPOSE_get0(idx) + local i = C.X509_PURPOSE_get_id(purp) + + if C.X509_STORE_set_purpose(self.ctx, i) ~= 1 then + return false, format_error("x509.store:set_purpose: X509_STORE_set_purpose") + end + + return true +end + +function _M:set_flags(...) + local flag = 0 + for _, f in ipairs({...}) do + flag = bor(flag, f) + end + + if C.X509_STORE_set_flags(self.ctx, flag) ~= 1 then + return false, format_error("x509.store:set_flags: X509_STORE_set_flags") + end + + return true +end + +function _M:verify(x509, chain, return_chain, properties, verify_method) + if not x509_lib.istype(x509) then + return nil, "x509.store:verify: expect a x509 instance at #1" + elseif chain and not chain_lib.istype(chain) then + return nil, "x509.store:verify: expect a x509.chain instance at #1" + end + + local ctx + if OPENSSL_3X then + ctx = C.X509_STORE_CTX_new_ex(ctx_lib.get_libctx(), properties) + else + ctx = C.X509_STORE_CTX_new() + end + if ctx == nil then + return nil, "x509.store:verify: X509_STORE_CTX_new() failed" + end + + ffi_gc(ctx, C.X509_STORE_CTX_free) + + local chain_dup_ctx + if chain then + local chain_dup, err = chain_lib.dup(chain.ctx) + if err then + return nil, err + end + chain_dup_ctx = chain_dup.ctx + end + + if C.X509_STORE_CTX_init(ctx, self.ctx, x509.ctx, chain_dup_ctx) ~= 1 then + return nil, format_error("x509.store:verify: X509_STORE_CTX_init") + end + + if verify_method and C.X509_STORE_CTX_set_default(ctx, verify_method) ~= 1 then + return nil, "x509.store:verify: invalid verify_method \"" .. verify_method .. "\"" + end + + local code = C.X509_verify_cert(ctx) + if code == 1 then -- verified + if not return_chain then + return true, nil + end + local ret_chain_ctx = x509_vfy_macro.X509_STORE_CTX_get0_chain(ctx) + return chain_lib.dup(ret_chain_ctx) + elseif code == 0 then -- unverified + local vfy_code = C.X509_STORE_CTX_get_error(ctx) + + return nil, ffi_str(C.X509_verify_cert_error_string(vfy_code)) + end + + -- error + return nil, format_error("x509.store:verify: X509_verify_cert", code) + +end + +return _M diff --git a/server/resty/session.lua b/server/resty/session.lua new file mode 100644 index 0000000..12dfe53 --- /dev/null +++ b/server/resty/session.lua @@ -0,0 +1,771 @@ +local require = require + +local random = require "resty.random" + +local ngx = ngx +local var = ngx.var +local time = ngx.time +local header = ngx.header +local http_time = ngx.http_time +local set_header = ngx.req.set_header +local clear_header = ngx.req.clear_header +local concat = table.concat +local ceil = math.ceil +local max = math.max +local find = string.find +local gsub = string.gsub +local byte = string.byte +local sub = string.sub +local type = type +local pcall = pcall +local tonumber = tonumber +local setmetatable = setmetatable +local getmetatable = getmetatable +local bytes = random.bytes + +local UNDERSCORE = byte("_") +local EXPIRE_FLAGS = "; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Max-Age=0" + +local COOKIE_PARTS = { + DEFAULT = { + n = 3, + "id", + "expires", -- may also contain: `expires:usebefore` + "hash" + }, + cookie = { + n = 4, + "id", + "expires", -- may also contain: `expires:usebefore` + "data", + "hash", + }, +} + +local function enabled(value) + if value == nil then + return nil + end + + return value == true + or value == "1" + or value == "true" + or value == "on" +end + +local function ifnil(value, default) + if value == nil then + return default + end + + return enabled(value) +end + +local function prequire(prefix, package, default) + if type(package) == "table" then + return package, package.name + end + + local ok, module = pcall(require, prefix .. package) + if not ok then + return require(prefix .. default), default + end + + return module, package +end + +local function is_session_cookie(cookie, name, name_len) + if not cookie or cookie == "" then + return false, nil + end + + cookie = gsub(cookie, "^%s+", "") + if cookie == "" then + return false, nil + end + + cookie = gsub(cookie, "%s+$", "") + if cookie == "" then + return false, nil + end + + local eq_pos = find(cookie, "=", 1, true) + if not eq_pos then + return false, cookie + end + + local cookie_name = sub(cookie, 1, eq_pos - 1) + if cookie_name == "" then + return false, cookie + end + + cookie_name = gsub(cookie_name, "%s+$", "") + if cookie_name == "" then + return false, cookie + end + + if cookie_name ~= name then + if find(cookie_name, name, 1, true) ~= 1 then + return false, cookie + end + + if byte(cookie_name, name_len + 1) ~= UNDERSCORE then + return false, cookie + end + + if not tonumber(sub(cookie_name, name_len + 2), 10) then + return false, cookie + end + end + + return true, cookie +end + +local function set_cookie(session, value, expires) + if ngx.headers_sent then + return nil, "attempt to set session cookie after sending out response headers" + end + + value = value or "" + + local cookie = session.cookie + local output = {} + + local i = 3 + + -- build cookie parameters, elements 1+2 will be set later + if expires then + -- we're expiring/deleting the data, so set an expiry in the past + output[i] = EXPIRE_FLAGS + elseif cookie.persistent then + -- persistent cookies have an expiry + output[i] = "; Expires=" .. http_time(session.expires) .. "; Max-Age=" .. cookie.lifetime + else + -- just to reserve index 3 for expiry as cookie might get smaller, + -- and some cookies need to be expired. + output[i] = "" + end + + if cookie.domain and cookie.domain ~= "localhost" and cookie.domain ~= "" then + i = i + 1 + output[i] = "; Domain=" .. cookie.domain + end + + i = i + 1 + output[i] = "; Path=" .. (cookie.path or "/") + + if cookie.samesite == "Lax" + or cookie.samesite == "Strict" + or cookie.samesite == "None" + then + i = i + 1 + output[i] = "; SameSite=" .. cookie.samesite + end + + if cookie.secure then + i = i + 1 + output[i] = "; Secure" + end + + if cookie.httponly then + i = i + 1 + output[i] = "; HttpOnly" + end + + -- How many chunks do we need? + local cookie_parts + local cookie_chunks + if expires then + -- expiring cookie, so deleting data. Do not measure data, but use + -- existing chunk count to make sure we clear all of them + cookie_parts = cookie.chunks or 1 + else + -- calculate required chunks from data + cookie_chunks = max(ceil(#value / cookie.maxsize), 1) + cookie_parts = max(cookie_chunks, cookie.chunks or 1) + end + + local cookie_header = header["Set-Cookie"] + for j = 1, cookie_parts do + -- create numbered chunk names if required + local chunk_name = { session.name } + if j > 1 then + chunk_name[2] = "_" + chunk_name[3] = j + chunk_name[4] = "=" + else + chunk_name[2] = "=" + end + chunk_name = concat(chunk_name) + output[1] = chunk_name + + if expires then + -- expiring cookie, so deleting data; clear it + output[2] = "" + elseif j > cookie_chunks then + -- less chunks than before, clearing excess cookies + output[2] = "" + output[3] = EXPIRE_FLAGS + + else + -- grab the piece for the current chunk + local sp = j * cookie.maxsize - (cookie.maxsize - 1) + if j < cookie_chunks then + output[2] = sub(value, sp, sp + (cookie.maxsize - 1)) .. "0" + else + output[2] = sub(value, sp) + end + end + + -- build header value and add it to the header table/string + -- replace existing chunk-name, or append + local cookie_content = concat(output) + local header_type = type(cookie_header) + if header_type == "table" then + local found = false + local cookie_count = #cookie_header + for cookie_index = 1, cookie_count do + if find(cookie_header[cookie_index], chunk_name, 1, true) == 1 then + cookie_header[cookie_index] = cookie_content + found = true + break + end + end + if not found then + cookie_header[cookie_count + 1] = cookie_content + end + elseif header_type == "string" and find(cookie_header, chunk_name, 1, true) ~= 1 then + cookie_header = { cookie_header, cookie_content } + else + cookie_header = cookie_content + end + end + + header["Set-Cookie"] = cookie_header + + return true +end + +local function get_cookie(session, i) + local cookie_name = { "cookie_", session.name } + if i then + cookie_name[3] = "_" + cookie_name[4] = i + else + i = 1 + end + + local cookie = var[concat(cookie_name)] + if not cookie then + return nil + end + + session.cookie.chunks = i + + local cookie_size = #cookie + if cookie_size <= session.cookie.maxsize then + return cookie + end + + return concat{ sub(cookie, 1, session.cookie.maxsize), get_cookie(session, i + 1) or "" } +end + +local function set_usebefore(session) + local usebefore = session.usebefore + local idletime = session.cookie.idletime + + if idletime == 0 then -- usebefore is disabled + if usebefore then + session.usebefore = nil + return true + end + + return false + end + + usebefore = usebefore or 0 + + local new_usebefore = session.now + idletime + if new_usebefore - usebefore > 60 then + session.usebefore = new_usebefore + return true + end + + return false +end + +local function save(session, close) + session.expires = session.now + session.cookie.lifetime + + set_usebefore(session) + + local cookie, err = session.strategy.save(session, close) + if not cookie then + return nil, err or "unable to save session cookie" + end + + return set_cookie(session, cookie) +end + +local function touch(session, close) + if set_usebefore(session) then + -- usebefore was updated, so set cookie + local cookie, err = session.strategy.touch(session, close) + if not cookie then + return nil, err or "unable to touch session cookie" + end + + return set_cookie(session, cookie) + end + + if close then + local ok, err = session.strategy.close(session) + if not ok then + return nil, err + end + end + + return true +end + +local function regenerate(session, flush) + if session.strategy.destroy then + session.strategy.destroy(session) + elseif session.strategy.close then + session.strategy.close(session) + end + + if flush then + session.data = {} + end + + session.id = session:identifier() +end + +local secret = bytes(32, true) or bytes(32) +local defaults + +local function init() + defaults = { + name = var.session_name or "session", + identifier = var.session_identifier or "random", + strategy = var.session_strategy or "default", + storage = var.session_storage or "cookie", + serializer = var.session_serializer or "json", + compressor = var.session_compressor or "none", + encoder = var.session_encoder or "base64", + cipher = var.session_cipher or "aes", + hmac = var.session_hmac or "sha1", + cookie = { + path = var.session_cookie_path or "/", + domain = var.session_cookie_domain, + samesite = var.session_cookie_samesite or "Lax", + secure = enabled(var.session_cookie_secure), + httponly = enabled(var.session_cookie_httponly or true), + persistent = enabled(var.session_cookie_persistent or false), + discard = tonumber(var.session_cookie_discard, 10) or 10, + renew = tonumber(var.session_cookie_renew, 10) or 600, + lifetime = tonumber(var.session_cookie_lifetime, 10) or 3600, + idletime = tonumber(var.session_cookie_idletime, 10) or 0, + maxsize = tonumber(var.session_cookie_maxsize, 10) or 4000, + + }, check = { + ssi = enabled(var.session_check_ssi or false), + ua = enabled(var.session_check_ua or true), + scheme = enabled(var.session_check_scheme or true), + addr = enabled(var.session_check_addr or false) + } + } + defaults.secret = var.session_secret or secret +end + +local session = { + _VERSION = "3.10" +} + +session.__index = session + +function session:get_cookie() + return get_cookie(self) +end + +function session:parse_cookie(value) + local cookie + local cookie_parts = COOKIE_PARTS[self.cookie.storage] or COOKIE_PARTS.DEFAULT + + local count = 1 + local pos = 1 + + local p_pos = find(value, "|", 1, true) + while p_pos do + if count > (cookie_parts.n - 1) then + return nil, "too many session cookie parts" + end + if not cookie then + cookie = {} + end + + if count == 2 then + local cookie_part = sub(value, pos, p_pos - 1) + local c_pos = find(cookie_part, ":", 2, true) + if c_pos then + cookie.expires = tonumber(sub(cookie_part, 1, c_pos - 1), 10) + if not cookie.expires then + return nil, "invalid session cookie expiry" + end + + cookie.usebefore = tonumber(sub(cookie_part, c_pos + 1), 10) + if not cookie.usebefore then + return nil, "invalid session cookie usebefore" + end + else + cookie.expires = tonumber(cookie_part, 10) + if not cookie.expires then + return nil, "invalid session cookie expiry" + end + end + else + local name = cookie_parts[count] + + local cookie_part = self.encoder.decode(sub(value, pos, p_pos - 1)) + if not cookie_part then + return nil, "unable to decode session cookie part (" .. name .. ")" + end + + cookie[name] = cookie_part + end + + count = count + 1 + pos = p_pos + 1 + + p_pos = find(value, "|", pos, true) + end + + if count ~= cookie_parts.n then + return nil, "invalid number of session cookie parts" + end + + local name = cookie_parts[count] + + local cookie_part = self.encoder.decode(sub(value, pos)) + if not cookie_part then + return nil, "unable to decode session cookie part (" .. name .. ")" + end + + cookie[name] = cookie_part + + if not cookie.id then + return nil, "missing session cookie id" + end + + if not cookie.expires then + return nil, "missing session cookie expiry" + end + + if cookie.expires <= self.now then + return nil, "session cookie has expired" + end + + if cookie.usebefore and cookie.usebefore <= self.now then + return nil, "session cookie idle time has passed" + end + + if not cookie.hash then + return nil, "missing session cookie signature" + end + + return cookie +end + +function session.new(opts) + if opts and getmetatable(opts) == session then + return opts + end + + if not defaults then + init() + end + + opts = type(opts) == "table" and opts or defaults + + local cookie = opts.cookie or defaults.cookie + local name = opts.name or defaults.name + local sec = opts.secret or defaults.secret + + local secure + local path + local domain + if find(name, "__Host-", 1, true) == 1 then + secure = true + path = "/" + else + if find(name, "__Secure-", 1, true) == 1 then + secure = true + else + secure = ifnil(cookie.secure, defaults.cookie.secure) + end + + domain = cookie.domain or defaults.cookie.domain + path = cookie.path or defaults.cookie.path + end + + local check = opts.check or defaults.check + + local ide, iden = prequire("resty.session.identifiers.", opts.identifier or defaults.identifier, "random") + local ser, sern = prequire("resty.session.serializers.", opts.serializer or defaults.serializer, "json") + local com, comn = prequire("resty.session.compressors.", opts.compressor or defaults.compressor, "none") + local enc, encn = prequire("resty.session.encoders.", opts.encoder or defaults.encoder, "base64") + local cip, cipn = prequire("resty.session.ciphers.", opts.cipher or defaults.cipher, "aes") + local sto, ston = prequire("resty.session.storage.", opts.storage or defaults.storage, "cookie") + local str, strn = prequire("resty.session.strategies.", opts.strategy or defaults.strategy, "default") + local hma, hman = prequire("resty.session.hmac.", opts.hmac or defaults.hmac, "sha1") + + local self = { + now = time(), + name = name, + secret = sec, + identifier = ide, + serializer = ser, + strategy = str, + encoder = enc, + hmac = hma, + cookie = { + storage = ston, + encoder = enc, + path = path, + domain = domain, + secure = secure, + samesite = cookie.samesite or defaults.cookie.samesite, + httponly = ifnil(cookie.httponly, defaults.cookie.httponly), + persistent = ifnil(cookie.persistent, defaults.cookie.persistent), + discard = tonumber(cookie.discard, 10) or defaults.cookie.discard, + renew = tonumber(cookie.renew, 10) or defaults.cookie.renew, + lifetime = tonumber(cookie.lifetime, 10) or defaults.cookie.lifetime, + idletime = tonumber(cookie.idletime, 10) or defaults.cookie.idletime, + maxsize = tonumber(cookie.maxsize, 10) or defaults.cookie.maxsize, + }, check = { + ssi = ifnil(check.ssi, defaults.check.ssi), + ua = ifnil(check.ua, defaults.check.ua), + scheme = ifnil(check.scheme, defaults.check.scheme), + addr = ifnil(check.addr, defaults.check.addr), + } + } + if self.cookie.idletime > 0 and self.cookie.discard > self.cookie.idletime then + -- if using idletime, then the discard period must be less or equal + self.cookie.discard = self.cookie.idletime + end + + if iden and not self[iden] then self[iden] = opts[iden] end + if sern and not self[sern] then self[sern] = opts[sern] end + if comn and not self[comn] then self[comn] = opts[comn] end + if encn and not self[encn] then self[encn] = opts[encn] end + if cipn and not self[cipn] then self[cipn] = opts[cipn] end + if ston and not self[ston] then self[ston] = opts[ston] end + if strn and not self[strn] then self[strn] = opts[strn] end + if hman and not self[hman] then self[hman] = opts[hman] end + + self.cipher = cip.new(self) + self.storage = sto.new(self) + self.compressor = com.new(self) + + return setmetatable(self, session) +end + +function session.open(opts, keep_lock) + local self = opts + if self and getmetatable(self) == session then + if self.opened then + return self, self.present + end + else + self = session.new(opts) + end + + if self.cookie.secure == nil then + self.cookie.secure = var.scheme == "https" or var.https == "on" + end + + self.now = time() + self.key = concat { + self.check.ssi and var.ssl_session_id or "", + self.check.ua and var.http_user_agent or "", + self.check.addr and var.remote_addr or "", + self.check.scheme and var.scheme or "", + } + + self.opened = true + + local err + local cookie = self:get_cookie() + if cookie then + cookie, err = self:parse_cookie(cookie) + if cookie then + local ok + ok, err = self.strategy.open(self, cookie, keep_lock) + if ok then + return self, true + end + end + end + + regenerate(self, true) + + return self, false, err +end + +function session.start(opts) + if opts and getmetatable(opts) == session and opts.started then + return opts, opts.present + end + + local self, present, reason = session.open(opts, true) + + self.started = true + + if not present then + local ok, err = save(self) + if not ok then + return nil, err or "unable to save session cookie" + end + + return self, present, reason + end + + if self.strategy.start then + local ok, err = self.strategy.start(self) + if not ok then + return nil, err or "unable to start session" + end + end + + if self.expires - self.now < self.cookie.renew + or self.expires > self.now + self.cookie.lifetime + then + local ok, err = save(self) + if not ok then + return nil, err or "unable to save session cookie" + end + else + -- we're not saving, so we must touch to update idletime/usebefore + local ok, err = touch(self) + if not ok then + return nil, err or "unable to touch session cookie" + end + end + + return self, true +end + +function session.destroy(opts) + if opts and getmetatable(opts) == session and opts.destroyed then + return true + end + + local self, err = session.start(opts) + if not self then + return nil, err + end + + if self.strategy.destroy then + self.strategy.destroy(self) + elseif self.strategy.close then + self.strategy.close(self) + end + + self.data = {} + self.present = nil + self.opened = nil + self.started = nil + self.closed = true + self.destroyed = true + + return set_cookie(self, "", true) +end + +function session:regenerate(flush, close) + close = close ~= false + if self.strategy.regenerate then + if flush then + self.data = {} + end + + if not self.id then + self.id = session:identifier() + end + else + regenerate(self, flush) + end + + return save(self, close) +end + +function session:save(close) + close = close ~= false + + if not self.id then + self.id = self:identifier() + end + + return save(self, close) +end + +function session:close() + self.closed = true + + if self.strategy.close then + return self.strategy.close(self) + end + + return true +end + +function session:hide() + local cookies = var.http_cookie + if not cookies or cookies == "" then + return + end + + local results = {} + local name = self.name + local name_len = #name + local found + local i = 1 + local j = 0 + local sc_pos = find(cookies, ";", i, true) + while sc_pos do + local isc, cookie = is_session_cookie(sub(cookies, i, sc_pos - 1), name, name_len) + if isc then + found = true + elseif cookie then + j = j + 1 + results[j] = cookie + end + + i = sc_pos + 1 + sc_pos = find(cookies, ";", i, true) + end + + local isc, cookie + if i == 1 then + isc, cookie = is_session_cookie(cookies, name, name_len) + else + isc, cookie = is_session_cookie(sub(cookies, i), name, name_len) + end + + if not isc and cookie then + if not found then + return + end + + j = j + 1 + results[j] = cookie + end + + if j == 0 then + clear_header("Cookie") + else + set_header("Cookie", concat(results, "; ", 1, j)) + end +end + +return session diff --git a/server/resty/session/ciphers/aes.lua b/server/resty/session/ciphers/aes.lua new file mode 100644 index 0000000..9a088ad --- /dev/null +++ b/server/resty/session/ciphers/aes.lua @@ -0,0 +1,113 @@ +local aes = require "resty.aes" + +local setmetatable = setmetatable +local tonumber = tonumber +local ceil = math.ceil +local var = ngx.var +local sub = string.sub +local rep = string.rep + +local HASHES = aes.hash + +local CIPHER_MODES = { + ecb = "ecb", + cbc = "cbc", + cfb1 = "cfb1", + cfb8 = "cfb8", + cfb128 = "cfb128", + ofb = "ofb", + ctr = "ctr", + gcm = "gcm", +} + +local CIPHER_SIZES = { + [128] = 128, + [192] = 192, + [256] = 256, +} + +local defaults = { + size = CIPHER_SIZES[tonumber(var.session_aes_size, 10)] or 256, + mode = CIPHER_MODES[var.session_aes_mode] or "cbc", + hash = HASHES[var.session_aes_hash] or HASHES.sha512, + rounds = tonumber(var.session_aes_rounds, 10) or 1, +} + +local function adjust_salt(salt) + if not salt then + return nil + end + + local z = #salt + if z < 8 then + return sub(rep(salt, ceil(8 / z)), 1, 8) + end + if z > 8 then + return sub(salt, 1, 8) + end + + return salt +end + +local function get_cipher(self, key, salt) + local mode = aes.cipher(self.size, self.mode) + if not mode then + return nil, "invalid cipher mode " .. self.mode .. "(" .. self.size .. ")" + end + + return aes:new(key, adjust_salt(salt), mode, self.hash, self.rounds) +end + +local cipher = {} + +cipher.__index = cipher + +function cipher.new(session) + local config = session.aes or defaults + return setmetatable({ + size = CIPHER_SIZES[tonumber(config.size, 10)] or defaults.size, + mode = CIPHER_MODES[config.mode] or defaults.mode, + hash = HASHES[config.hash] or defaults.hash, + rounds = tonumber(config.rounds, 10) or defaults.rounds, + }, cipher) +end + +function cipher:encrypt(data, key, salt, _) + local cip, err = get_cipher(self, key, salt) + if not cip then + return nil, err or "unable to aes encrypt data" + end + + local encrypted_data + encrypted_data, err = cip:encrypt(data) + if not encrypted_data then + return nil, err or "aes encryption failed" + end + + if self.mode == "gcm" then + return encrypted_data[1], nil, encrypted_data[2] + end + + return encrypted_data +end + +function cipher:decrypt(data, key, salt, _, tag) + local cip, err = get_cipher(self, key, salt) + if not cip then + return nil, err or "unable to aes decrypt data" + end + + local decrypted_data + decrypted_data, err = cip:decrypt(data, tag) + if not decrypted_data then + return nil, err or "aes decryption failed" + end + + if self.mode == "gcm" then + return decrypted_data, nil, tag + end + + return decrypted_data +end + +return cipher diff --git a/server/resty/session/ciphers/none.lua b/server/resty/session/ciphers/none.lua new file mode 100644 index 0000000..b29bb88 --- /dev/null +++ b/server/resty/session/ciphers/none.lua @@ -0,0 +1,15 @@ +local cipher = {} + +function cipher.new() + return cipher +end + +function cipher.encrypt(_, data, _, _) + return data +end + +function cipher.decrypt(_, data, _, _, _) + return data +end + +return cipher diff --git a/server/resty/session/compressors/none.lua b/server/resty/session/compressors/none.lua new file mode 100644 index 0000000..3d14a5c --- /dev/null +++ b/server/resty/session/compressors/none.lua @@ -0,0 +1,15 @@ +local compressor = {} + +function compressor.new() + return compressor +end + +function compressor.compress(_, data) + return data +end + +function compressor.decompress(_, data) + return data +end + +return compressor diff --git a/server/resty/session/compressors/zlib.lua b/server/resty/session/compressors/zlib.lua new file mode 100644 index 0000000..1d23be0 --- /dev/null +++ b/server/resty/session/compressors/zlib.lua @@ -0,0 +1,43 @@ +local zlib = require "ffi-zlib" +local sio = require "pl.stringio" + +local concat = table.concat + +local function gzip(func, input) + local stream = sio.open(input) + local output = {} + local n = 0 + + local ok, err = func(function(size) + return stream:read(size) + end, function(data) + n = n + 1 + output[n] = data + end, 8192) + + if not ok then + return nil, err + end + + if n == 0 then + return "" + end + + return concat(output, nil, 1, n) +end + +local compressor = {} + +function compressor.new() + return compressor +end + +function compressor.compress(_, data) + return gzip(zlib.deflateGzip, data) +end + +function compressor.decompress(_, data) + return gzip(zlib.inflateGzip, data) +end + +return compressor diff --git a/server/resty/session/encoders/base16.lua b/server/resty/session/encoders/base16.lua new file mode 100644 index 0000000..552f50e --- /dev/null +++ b/server/resty/session/encoders/base16.lua @@ -0,0 +1,29 @@ +local to_hex = require "resty.string".to_hex + +local tonumber = tonumber +local gsub = string.gsub +local char = string.char + +local function chr(c) + return char(tonumber(c, 16) or 0) +end + +local encoder = {} + +function encoder.encode(value) + if not value then + return nil, "unable to base16 encode value" + end + + return to_hex(value) +end + +function encoder.decode(value) + if not value then + return nil, "unable to base16 decode value" + end + + return (gsub(value, "..", chr)) +end + +return encoder diff --git a/server/resty/session/encoders/base64.lua b/server/resty/session/encoders/base64.lua new file mode 100644 index 0000000..ddaf4e8 --- /dev/null +++ b/server/resty/session/encoders/base64.lua @@ -0,0 +1,39 @@ +local encode_base64 = ngx.encode_base64 +local decode_base64 = ngx.decode_base64 + +local gsub = string.gsub + +local ENCODE_CHARS = { + ["+"] = "-", + ["/"] = "_", +} + +local DECODE_CHARS = { + ["-"] = "+", + ["_"] = "/", +} + +local encoder = {} + +function encoder.encode(value) + if not value then + return nil, "unable to base64 encode value" + end + + local encoded = encode_base64(value, true) + if not encoded then + return nil, "unable to base64 encode value" + end + + return gsub(encoded, "[+/]", ENCODE_CHARS) +end + +function encoder.decode(value) + if not value then + return nil, "unable to base64 decode value" + end + + return decode_base64((gsub(value, "[-_]", DECODE_CHARS))) +end + +return encoder diff --git a/server/resty/session/encoders/hex.lua b/server/resty/session/encoders/hex.lua new file mode 100644 index 0000000..1b94a5a --- /dev/null +++ b/server/resty/session/encoders/hex.lua @@ -0,0 +1 @@ +return require "resty.session.encoders.base16"
\ No newline at end of file diff --git a/server/resty/session/hmac/sha1.lua b/server/resty/session/hmac/sha1.lua new file mode 100644 index 0000000..1753412 --- /dev/null +++ b/server/resty/session/hmac/sha1.lua @@ -0,0 +1 @@ +return ngx.hmac_sha1 diff --git a/server/resty/session/identifiers/random.lua b/server/resty/session/identifiers/random.lua new file mode 100644 index 0000000..a2f9739 --- /dev/null +++ b/server/resty/session/identifiers/random.lua @@ -0,0 +1,13 @@ +local tonumber = tonumber +local random = require "resty.random".bytes +local var = ngx.var + +local defaults = { + length = tonumber(var.session_random_length, 10) or 16 +} + +return function(session) + local config = session.random or defaults + local length = tonumber(config.length, 10) or defaults.length + return random(length, true) or random(length) +end diff --git a/server/resty/session/serializers/json.lua b/server/resty/session/serializers/json.lua new file mode 100644 index 0000000..960c4d8 --- /dev/null +++ b/server/resty/session/serializers/json.lua @@ -0,0 +1,6 @@ +local json = require "cjson.safe" + +return { + serialize = json.encode, + deserialize = json.decode, +} diff --git a/server/resty/session/storage/cookie.lua b/server/resty/session/storage/cookie.lua new file mode 100644 index 0000000..95e26d1 --- /dev/null +++ b/server/resty/session/storage/cookie.lua @@ -0,0 +1,7 @@ +local storage = {} + +function storage.new() + return storage +end + +return storage diff --git a/server/resty/session/storage/dshm.lua b/server/resty/session/storage/dshm.lua new file mode 100644 index 0000000..e6d887f --- /dev/null +++ b/server/resty/session/storage/dshm.lua @@ -0,0 +1,163 @@ +local dshm = require "resty.dshm" + +local setmetatable = setmetatable +local tonumber = tonumber +local concat = table.concat +local var = ngx.var + +local defaults = { + region = var.session_dshm_region or "sessions", + connect_timeout = tonumber(var.session_dshm_connect_timeout, 10), + read_timeout = tonumber(var.session_dshm_read_timeout, 10), + send_timeout = tonumber(var.session_dshm_send_timeout, 10), + host = var.session_dshm_host or "127.0.0.1", + port = tonumber(var.session_dshm_port, 10) or 4321, + pool = { + name = var.session_dshm_pool_name, + size = tonumber(var.session_dshm_pool_size, 10) or 100, + timeout = tonumber(var.session_dshm_pool_timeout, 10) or 1000, + backlog = tonumber(var.session_dshm_pool_backlog, 10), + }, +} + +local storage = {} + +storage.__index = storage + +function storage.new(session) + local config = session.dshm or defaults + local pool = config.pool or defaults.pool + + local connect_timeout = tonumber(config.connect_timeout, 10) or defaults.connect_timeout + + local store = dshm:new() + if store.set_timeouts then + local send_timeout = tonumber(config.send_timeout, 10) or defaults.send_timeout + local read_timeout = tonumber(config.read_timeout, 10) or defaults.read_timeout + + if connect_timeout then + if send_timeout and read_timeout then + store:set_timeouts(connect_timeout, send_timeout, read_timeout) + else + store:set_timeout(connect_timeout) + end + end + + elseif store.set_timeout and connect_timeout then + store:set_timeout(connect_timeout) + end + + + local self = { + store = store, + encoder = session.encoder, + region = config.region or defaults.region, + host = config.host or defaults.host, + port = tonumber(config.port, 10) or defaults.port, + pool_timeout = tonumber(pool.timeout, 10) or defaults.pool.timeout, + connect_opts = { + pool = pool.name or defaults.pool.name, + pool_size = tonumber(pool.size, 10) or defaults.pool.size, + backlog = tonumber(pool.backlog, 10) or defaults.pool.backlog, + }, + } + + return setmetatable(self, storage) +end + +function storage:connect() + return self.store:connect(self.host, self.port, self.connect_opts) +end + +function storage:set_keepalive() + return self.store:set_keepalive(self.pool_timeout) +end + +function storage:key(id) + return concat({ self.region, id }, "::") +end + +function storage:set(key, ttl, data) + local ok, err = self:connect() + if not ok then + return nil, err + end + + data, err = self.encoder.encode(data) + + if not data then + self:set_keepalive() + return nil, err + end + + ok, err = self.store:set(key, data, ttl) + + self:set_keepalive() + + return ok, err +end + +function storage:get(key) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local data + data, err = self.store:get(key) + if data then + data, err = self.encoder.decode(data) + end + + self:set_keepalive() + + return data, err +end + +function storage:delete(key) + local ok, err = self:connect() + if not ok then + return nil, err + end + + ok, err = self.store:delete(key) + + self:set_keepalive() + + return ok, err +end + +function storage:touch(key, ttl) + local ok, err = self:connect() + if not ok then + return nil, err + end + + ok, err = self.store:touch(key, ttl) + + self:set_keepalive() + + return ok, err +end + +function storage:open(id) + local key = self:key(id) + return self:get(key) +end + +function storage:save(id, ttl, data) + local key = self:key(id) + return self:set(key, ttl, data) +end + +function storage:destroy(id) + local key = self:key(id) + return self:delete(key) +end + +function storage:ttl(id, ttl) + local key = self:key(id) + return self:touch(key, ttl) +end + +return storage diff --git a/server/resty/session/storage/memcache.lua b/server/resty/session/storage/memcache.lua new file mode 100644 index 0000000..da44ba7 --- /dev/null +++ b/server/resty/session/storage/memcache.lua @@ -0,0 +1,303 @@ +local memcached = require "resty.memcached" +local setmetatable = setmetatable +local tonumber = tonumber +local concat = table.concat +local sleep = ngx.sleep +local null = ngx.null +local var = ngx.var + +local function enabled(value) + if value == nil then + return nil + end + + return value == true + or value == "1" + or value == "true" + or value == "on" +end + +local function ifnil(value, default) + if value == nil then + return default + end + + return enabled(value) +end + +local defaults = { + prefix = var.session_memcache_prefix or "sessions", + socket = var.session_memcache_socket, + host = var.session_memcache_host or "127.0.0.1", + uselocking = enabled(var.session_memcache_uselocking or true), + connect_timeout = tonumber(var.session_memcache_connect_timeout, 10), + read_timeout = tonumber(var.session_memcache_read_timeout, 10), + send_timeout = tonumber(var.session_memcache_send_timeout, 10), + port = tonumber(var.session_memcache_port, 10) or 11211, + spinlockwait = tonumber(var.session_memcache_spinlockwait, 10) or 150, + maxlockwait = tonumber(var.session_memcache_maxlockwait, 10) or 30, + pool = { + name = var.session_memcache_pool_name, + timeout = tonumber(var.session_memcache_pool_timeout, 10), + size = tonumber(var.session_memcache_pool_size, 10), + backlog = tonumber(var.session_memcache_pool_backlog, 10), + }, +} + +local storage = {} + +storage.__index = storage + +function storage.new(session) + local config = session.memcache or defaults + local pool = config.pool or defaults.pool + local locking = ifnil(config.uselocking, defaults.uselocking) + + local connect_timeout = tonumber(config.connect_timeout, 10) or defaults.connect_timeout + + local memcache = memcached:new() + if memcache.set_timeouts then + local send_timeout = tonumber(config.send_timeout, 10) or defaults.send_timeout + local read_timeout = tonumber(config.read_timeout, 10) or defaults.read_timeout + + if connect_timeout then + if send_timeout and read_timeout then + memcache:set_timeouts(connect_timeout, send_timeout, read_timeout) + else + memcache:set_timeout(connect_timeout) + end + end + + elseif memcache.set_timeout and connect_timeout then + memcache:set_timeout(connect_timeout) + end + + local self = { + memcache = memcache, + prefix = config.prefix or defaults.prefix, + uselocking = locking, + spinlockwait = tonumber(config.spinlockwait, 10) or defaults.spinlockwait, + maxlockwait = tonumber(config.maxlockwait, 10) or defaults.maxlockwait, + pool_timeout = tonumber(pool.timeout, 10) or defaults.pool.timeout, + connect_opts = { + pool = pool.name or defaults.pool.name, + pool_size = tonumber(pool.size, 10) or defaults.pool.size, + backlog = tonumber(pool.backlog, 10) or defaults.pool.backlog, + }, + } + + local socket = config.socket or defaults.socket + if socket and socket ~= "" then + self.socket = socket + else + self.host = config.host or defaults.host + self.port = config.port or defaults.port + end + + return setmetatable(self, storage) +end + +function storage:connect() + local socket = self.socket + if socket then + return self.memcache:connect(socket, self.connect_opts) + end + return self.memcache:connect(self.host, self.port, self.connect_opts) +end + +function storage:set_keepalive() + return self.memcache:set_keepalive(self.pool_timeout) +end + +function storage:key(id) + return concat({ self.prefix, id }, ":" ) +end + +function storage:lock(key) + if not self.uselocking or self.locked then + return true + end + + if not self.token then + self.token = var.request_id + end + + local lock_key = concat({ key, "lock" }, "." ) + local lock_ttl = self.maxlockwait + 1 + local attempts = (1000 / self.spinlockwait) * self.maxlockwait + local waittime = self.spinlockwait / 1000 + + for _ = 1, attempts do + local ok = self.memcache:add(lock_key, self.token, lock_ttl) + if ok then + self.locked = true + return true + end + + sleep(waittime) + end + + return false, "unable to acquire a session lock" +end + +function storage:unlock(key) + if not self.uselocking or not self.locked then + return true + end + + local lock_key = concat({ key, "lock" }, "." ) + local token = self:get(lock_key) + + if token == self.token then + self.memcache:delete(lock_key) + self.locked = nil + end +end + +function storage:get(key) + local data, err = self.memcache:get(key) + if not data then + return nil, err + end + + if data == null then + return nil + end + + return data +end + +function storage:set(key, data, ttl) + return self.memcache:set(key, data, ttl) +end + +function storage:expire(key, ttl) + return self.memcache:touch(key, ttl) +end + +function storage:delete(key) + return self.memcache:delete(key) +end + +function storage:open(id, keep_lock) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:lock(key) + if not ok then + self:set_keepalive() + return nil, err + end + + local data + data, err = self:get(key) + + if err or not data or not keep_lock then + self:unlock(key) + end + + self:set_keepalive() + + return data, err +end + +function storage:start(id) + if not self.uselocking or not self.locked then + return true + end + + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:lock(key) + + self:set_keepalive() + + return ok, err +end + +function storage:save(id, ttl, data, close) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:set(key, data, ttl) + + if close then + self:unlock(key) + end + + self:set_keepalive() + + if not ok then + return nil, err + end + + return true +end + +function storage:close(id) + if not self.uselocking or not self.locked then + return true + end + + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + self:unlock(key) + self:set_keepalive() + + return true +end + +function storage:destroy(id) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:delete(key) + + self:unlock(key) + self:set_keepalive() + + return ok, err +end + +function storage:ttl(id, ttl, close) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:expire(key, ttl) + + if close then + self:unlock(key) + end + + self:set_keepalive() + + return ok, err +end + +return storage diff --git a/server/resty/session/storage/memcached.lua b/server/resty/session/storage/memcached.lua new file mode 100644 index 0000000..0ecc508 --- /dev/null +++ b/server/resty/session/storage/memcached.lua @@ -0,0 +1 @@ +return require "resty.session.storage.memcache" diff --git a/server/resty/session/storage/redis.lua b/server/resty/session/storage/redis.lua new file mode 100644 index 0000000..3de0472 --- /dev/null +++ b/server/resty/session/storage/redis.lua @@ -0,0 +1,478 @@ +local setmetatable = setmetatable +local tonumber = tonumber +local type = type +local reverse = string.reverse +local gmatch = string.gmatch +local find = string.find +local byte = string.byte +local sub = string.sub +local concat = table.concat +local sleep = ngx.sleep +local null = ngx.null +local var = ngx.var + +local LB = byte("[") +local RB = byte("]") + +local function parse_cluster_nodes(nodes) + if not nodes or nodes == "" then + return nil + end + + if type(nodes) == "table" then + return nodes + end + + local addrs + local i + for node in gmatch(nodes, "%S+") do + local ip = node + local port = 6379 + local pos = find(reverse(ip), ":", 2, true) + if pos then + local p = tonumber(sub(ip, -pos + 1), 10) + if p >= 1 and p <= 65535 then + local addr = sub(ip, 1, -pos - 1) + if find(addr, ":", 1, true) then + if byte(addr, -1) == RB then + ip = addr + port = p + end + + else + ip = addr + port = p + end + end + end + + if byte(ip, 1, 1) == LB then + ip = sub(ip, 2) + end + + if byte(ip, -1) == RB then + ip = sub(ip, 1, -2) + end + + if not addrs then + i = 1 + addrs = {{ + ip = ip, + port = port, + }} + else + i = i + 1 + addrs[i] = { + ip = ip, + port = port, + } + end + end + + if not i then + return + end + + return addrs +end + +local redis_single = require "resty.redis" +local redis_cluster +do + local pcall = pcall + local require = require + local ok + ok, redis_cluster = pcall(require, "resty.rediscluster") + if not ok then + ok, redis_cluster = pcall(require, "rediscluster") + if not ok then + redis_cluster = nil + end + end +end + +local UNLOCK = [[ +if redis.call("GET", KEYS[1]) == ARGV[1] then + return redis.call("DEL", KEYS[1]) +else + return 0 +end +]] + +local function enabled(value) + if value == nil then return nil end + return value == true or (value == "1" or value == "true" or value == "on") +end + +local function ifnil(value, default) + if value == nil then + return default + end + + return enabled(value) +end + +local defaults = { + prefix = var.session_redis_prefix or "sessions", + socket = var.session_redis_socket, + host = var.session_redis_host or "127.0.0.1", + username = var.session_redis_username, + password = var.session_redis_password or var.session_redis_auth, + server_name = var.session_redis_server_name, + ssl = enabled(var.session_redis_ssl) or false, + ssl_verify = enabled(var.session_redis_ssl_verify) or false, + uselocking = enabled(var.session_redis_uselocking or true), + port = tonumber(var.session_redis_port, 10) or 6379, + database = tonumber(var.session_redis_database, 10) or 0, + connect_timeout = tonumber(var.session_redis_connect_timeout, 10), + read_timeout = tonumber(var.session_redis_read_timeout, 10), + send_timeout = tonumber(var.session_redis_send_timeout, 10), + spinlockwait = tonumber(var.session_redis_spinlockwait, 10) or 150, + maxlockwait = tonumber(var.session_redis_maxlockwait, 10) or 30, + pool = { + name = var.session_redis_pool_name, + timeout = tonumber(var.session_redis_pool_timeout, 10), + size = tonumber(var.session_redis_pool_size, 10), + backlog = tonumber(var.session_redis_pool_backlog, 10), + }, +} + + +if redis_cluster then + defaults.cluster = { + name = var.session_redis_cluster_name, + dict = var.session_redis_cluster_dict, + maxredirections = tonumber(var.session_redis_cluster_maxredirections, 10), + nodes = parse_cluster_nodes(var.session_redis_cluster_nodes), + } +end + +local storage = {} + +storage.__index = storage + +function storage.new(session) + local config = session.redis or defaults + local pool = config.pool or defaults.pool + local cluster = config.cluster or defaults.cluster + local locking = ifnil(config.uselocking, defaults.uselocking) + + local self = { + prefix = config.prefix or defaults.prefix, + uselocking = locking, + spinlockwait = tonumber(config.spinlockwait, 10) or defaults.spinlockwait, + maxlockwait = tonumber(config.maxlockwait, 10) or defaults.maxlockwait, + } + + local username = config.username or defaults.username + if username == "" then + username = nil + end + local password = config.password or config.auth or defaults.password + if password == "" then + password = nil + end + + local connect_timeout = tonumber(config.connect_timeout, 10) or defaults.connect_timeout + + local cluster_nodes + if redis_cluster then + cluster_nodes = parse_cluster_nodes(cluster.nodes or defaults.cluster.nodes) + end + + local connect_opts = { + pool = pool.name or defaults.pool.name, + pool_size = tonumber(pool.size, 10) or defaults.pool.size, + backlog = tonumber(pool.backlog, 10) or defaults.pool.backlog, + server_name = config.server_name or defaults.server_name, + ssl = ifnil(config.ssl, defaults.ssl), + ssl_verify = ifnil(config.ssl_verify, defaults.ssl_verify), + } + + if cluster_nodes then + self.redis = redis_cluster:new({ + name = cluster.name or defaults.cluster.name, + dict_name = cluster.dict or defaults.cluster.dict, + username = var.session_redis_username, + password = var.session_redis_password or defaults.password, + connection_timout = connect_timeout, -- typo in library + connection_timeout = connect_timeout, + keepalive_timeout = tonumber(pool.timeout, 10) or defaults.pool.timeout, + keepalive_cons = tonumber(pool.size, 10) or defaults.pool.size, + max_redirection = tonumber(cluster.maxredirections, 10) or defaults.cluster.maxredirections, + serv_list = cluster_nodes, + connect_opts = connect_opts, + }) + self.cluster = true + + else + local redis = redis_single:new() + + if redis.set_timeouts then + local send_timeout = tonumber(config.send_timeout, 10) or defaults.send_timeout + local read_timeout = tonumber(config.read_timeout, 10) or defaults.read_timeout + + if connect_timeout then + if send_timeout and read_timeout then + redis:set_timeouts(connect_timeout, send_timeout, read_timeout) + else + redis:set_timeout(connect_timeout) + end + end + + elseif redis.set_timeout and connect_timeout then + redis:set_timeout(connect_timeout) + end + + self.redis = redis + self.username = username + self.password = password + self.database = tonumber(config.database, 10) or defaults.database + self.pool_timeout = tonumber(pool.timeout, 10) or defaults.pool.timeout + self.connect_opts = connect_opts + + local socket = config.socket or defaults.socket + if socket and socket ~= "" then + self.socket = socket + else + self.host = config.host or defaults.host + self.port = config.port or defaults.port + end + end + + return setmetatable(self, storage) +end + +function storage:connect() + if self.cluster then + return true -- cluster handles this on its own + end + + local ok, err + if self.socket then + ok, err = self.redis:connect(self.socket, self.connect_opts) + else + ok, err = self.redis:connect(self.host, self.port, self.connect_opts) + end + + if not ok then + return nil, err + end + + if self.password and self.redis:get_reused_times() == 0 then + -- usernames are supported only on Redis 6+, so use new AUTH form only when absolutely necessary + if self.username then + ok, err = self.redis:auth(self.username, self.password) + else + ok, err = self.redis:auth(self.password) + end + if not ok then + self.redis:close() + return nil, err + end + end + + if self.database ~= 0 then + ok, err = self.redis:select(self.database) + if not ok then + self.redis:close() + end + end + + return ok, err +end + +function storage:set_keepalive() + if self.cluster then + return true -- cluster handles this on its own + end + + return self.redis:set_keepalive(self.pool_timeout) +end + +function storage:key(id) + return concat({ self.prefix, id }, ":" ) +end + +function storage:lock(key) + if not self.uselocking or self.locked then + return true + end + + if not self.token then + self.token = var.request_id + end + + local lock_key = concat({ key, "lock" }, "." ) + local lock_ttl = self.maxlockwait + 1 + local attempts = (1000 / self.spinlockwait) * self.maxlockwait + local waittime = self.spinlockwait / 1000 + + for _ = 1, attempts do + local ok = self.redis:set(lock_key, self.token, "EX", lock_ttl, "NX") + if ok ~= null then + self.locked = true + return true + end + + sleep(waittime) + end + + return false, "unable to acquire a session lock" +end + +function storage:unlock(key) + if not self.uselocking or not self.locked then + return + end + + local lock_key = concat({ key, "lock" }, "." ) + + self.redis:eval(UNLOCK, 1, lock_key, self.token) + self.locked = nil +end + +function storage:get(key) + local data, err = self.redis:get(key) + if not data then + return nil, err + end + + if data == null then + return nil + end + + return data +end + +function storage:set(key, data, lifetime) + return self.redis:setex(key, lifetime, data) +end + +function storage:expire(key, lifetime) + return self.redis:expire(key, lifetime) +end + +function storage:delete(key) + return self.redis:del(key) +end + +function storage:open(id, keep_lock) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:lock(key) + if not ok then + self:set_keepalive() + return nil, err + end + + local data + data, err = self:get(key) + + if err or not data or not keep_lock then + self:unlock(key) + end + self:set_keepalive() + + return data, err +end + +function storage:start(id) + if not self.uselocking or not self.locked then + return true + end + + local ok, err = self:connect() + if not ok then + return nil, err + end + + ok, err = self:lock(self:key(id)) + + self:set_keepalive() + + return ok, err +end + +function storage:save(id, ttl, data, close) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:set(key, data, ttl) + + if close then + self:unlock(key) + end + + self:set_keepalive() + + if not ok then + return nil, err + end + + return true +end + +function storage:close(id) + if not self.uselocking or not self.locked then + return true + end + + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + self:unlock(key) + self:set_keepalive() + + return true +end + +function storage:destroy(id) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:delete(key) + + self:unlock(key) + self:set_keepalive() + + return ok, err +end + +function storage:ttl(id, ttl, close) + local ok, err = self:connect() + if not ok then + return nil, err + end + + local key = self:key(id) + + ok, err = self:expire(key, ttl) + + if close then + self:unlock(key) + end + + self:set_keepalive() + + return ok, err +end + +return storage diff --git a/server/resty/session/storage/shm.lua b/server/resty/session/storage/shm.lua new file mode 100644 index 0000000..6f81435 --- /dev/null +++ b/server/resty/session/storage/shm.lua @@ -0,0 +1,125 @@ +local lock = require "resty.lock" + +local setmetatable = setmetatable +local tonumber = tonumber +local concat = table.concat +local var = ngx.var +local shared = ngx.shared + +local function enabled(value) + if value == nil then return nil end + return value == true or (value == "1" or value == "true" or value == "on") +end + +local function ifnil(value, default) + if value == nil then + return default + end + + return enabled(value) +end + +local defaults = { + store = var.session_shm_store or "sessions", + uselocking = enabled(var.session_shm_uselocking or true), + lock = { + exptime = tonumber(var.session_shm_lock_exptime, 10) or 30, + timeout = tonumber(var.session_shm_lock_timeout, 10) or 5, + step = tonumber(var.session_shm_lock_step, 10) or 0.001, + ratio = tonumber(var.session_shm_lock_ratio, 10) or 2, + max_step = tonumber(var.session_shm_lock_max_step, 10) or 0.5, + } +} + +local storage = {} + +storage.__index = storage + +function storage.new(session) + local config = session.shm or defaults + local store = config.store or defaults.store + local locking = ifnil(config.uselocking, defaults.uselocking) + + local self = { + store = shared[store], + uselocking = locking, + } + + if locking then + local lock_opts = config.lock or defaults.lock + local opts = { + exptime = tonumber(lock_opts.exptime, 10) or defaults.exptime, + timeout = tonumber(lock_opts.timeout, 10) or defaults.timeout, + step = tonumber(lock_opts.step, 10) or defaults.step, + ratio = tonumber(lock_opts.ratio, 10) or defaults.ratio, + max_step = tonumber(lock_opts.max_step, 10) or defaults.max_step, + } + self.lock = lock:new(store, opts) + end + + return setmetatable(self, storage) +end + +function storage:open(id, keep_lock) + if self.uselocking then + local ok, err = self.lock:lock(concat{ id, ".lock" }) + if not ok then + return nil, err + end + end + + local data, err = self.store:get(id) + + if self.uselocking and (err or not data or not keep_lock) then + self.lock:unlock() + end + + return data, err +end + +function storage:start(id) + if self.uselocking then + return self.lock:lock(concat{ id, ".lock" }) + end + + return true +end + +function storage:save(id, ttl, data, close) + local ok, err = self.store:set(id, data, ttl) + if close and self.uselocking then + self.lock:unlock() + end + + return ok, err +end + +function storage:close() + if self.uselocking then + self.lock:unlock() + end + + return true +end + +function storage:destroy(id) + self.store:delete(id) + + if self.uselocking then + self.lock:unlock() + end + + return true +end + +function storage:ttl(id, lifetime, close) + local ok, err = self.store:expire(id, lifetime) + + if close and self.uselocking then + self.lock:unlock() + end + + return ok, err +end + +return storage diff --git a/server/resty/session/strategies/default.lua b/server/resty/session/strategies/default.lua new file mode 100644 index 0000000..a43ef5a --- /dev/null +++ b/server/resty/session/strategies/default.lua @@ -0,0 +1,232 @@ +local type = type +local concat = table.concat + +local strategy = {} + +function strategy.load(session, cookie, key, keep_lock) + local storage = session.storage + local id = cookie.id + local id_encoded = session.encoder.encode(id) + + local data, err, tag + if storage.open then + data, err = storage:open(id_encoded, keep_lock) + if not data then + return nil, err or "cookie data was not found" + end + + else + data = cookie.data + end + + local expires = cookie.expires + local usebefore = cookie.usebefore + local hash = cookie.hash + + if not key then + key = concat{ id, expires, usebefore } + end + + local hkey = session.hmac(session.secret, key) + + data, err, tag = session.cipher:decrypt(data, hkey, id, session.key, hash) + if not data then + if storage.close then + storage:close(id_encoded) + end + + return nil, err or "unable to decrypt data" + end + + if tag then + if tag ~= hash then + if storage.close then + storage:close(id_encoded) + end + + return nil, "cookie has invalid tag" + end + + else + local input = concat{ key, data, session.key } + if session.hmac(hkey, input) ~= hash then + if storage.close then + storage:close(id_encoded) + end + + return nil, "cookie has invalid signature" + end + end + + data, err = session.compressor:decompress(data) + if not data then + if storage.close then + storage:close(id_encoded) + end + + return nil, err or "unable to decompress data" + end + + data, err = session.serializer.deserialize(data) + if not data then + if storage.close then + storage:close(id_encoded) + end + + return nil, err or "unable to deserialize data" + end + + session.id = id + session.expires = expires + session.usebefore = usebefore + session.data = type(data) == "table" and data or {} + session.present = true + + return true +end + +function strategy.open(session, cookie, keep_lock) + return strategy.load(session, cookie, nil, keep_lock) +end + +function strategy.start(session) + local storage = session.storage + if not storage.start then + return true + end + + local id_encoded = session.encoder.encode(session.id) + + local ok, err = storage:start(id_encoded) + if not ok then + return nil, err or "unable to start session" + end + + return true +end + +function strategy.modify(session, action, close, key) + local id = session.id + local id_encoded = session.encoder.encode(id) + local storage = session.storage + local expires = session.expires + local usebefore = session.usebefore + local ttl = expires - session.now + + if ttl <= 0 then + if storage.close then + storage:close(id_encoded) + end + + return nil, "session is already expired" + end + + if not key then + key = concat{ id, expires, usebefore } + end + + local data, err = session.serializer.serialize(session.data) + if not data then + if close and storage.close then + storage:close(id_encoded) + end + + return nil, err or "unable to serialize data" + end + + data, err = session.compressor:compress(data) + if not data then + if close and storage.close then + storage:close(id_encoded) + end + + return nil, err or "unable to compress data" + end + + local hkey = session.hmac(session.secret, key) + + local encrypted_data, tag + encrypted_data, err, tag = session.cipher:encrypt(data, hkey, id, session.key) + if not encrypted_data then + if close and storage.close then + storage:close(id_encoded) + end + + return nil, err + end + + local hash + if tag then + hash = tag + else + -- it would be better to calculate signature from encrypted_data, + -- but this is kept for backward compatibility + hash = session.hmac(hkey, concat{ key, data, session.key }) + end + + if action == "save" and storage.save then + local ok + ok, err = storage:save(id_encoded, ttl, encrypted_data, close) + if not ok then + return nil, err + end + elseif close and storage.close then + local ok + ok, err = storage:close(id_encoded) + if not ok then + return nil, err + end + end + + if usebefore then + expires = expires .. ":" .. usebefore + end + + hash = session.encoder.encode(hash) + + local cookie + if storage.save then + cookie = concat({ id_encoded, expires, hash }, "|") + else + local encoded_data = session.encoder.encode(encrypted_data) + cookie = concat({ id_encoded, expires, encoded_data, hash }, "|") + end + + return cookie +end + +function strategy.touch(session, close) + return strategy.modify(session, "touch", close) +end + +function strategy.save(session, close) + return strategy.modify(session, "save", close) +end + +function strategy.destroy(session) + local id = session.id + if id then + local storage = session.storage + if storage.destroy then + return storage:destroy(session.encoder.encode(id)) + elseif storage.close then + return storage:close(session.encoder.encode(id)) + end + end + + return true +end + +function strategy.close(session) + local id = session.id + if id then + local storage = session.storage + if storage.close then + return storage:close(session.encoder.encode(id)) + end + end + + return true +end + +return strategy diff --git a/server/resty/session/strategies/regenerate.lua b/server/resty/session/strategies/regenerate.lua new file mode 100644 index 0000000..f2a97dd --- /dev/null +++ b/server/resty/session/strategies/regenerate.lua @@ -0,0 +1,43 @@ +local default = require "resty.session.strategies.default" + +local concat = table.concat + +local strategy = { + regenerate = true, + start = default.start, + destroy = default.destroy, + close = default.close, +} + +local function key(source) + if source.usebefore then + return concat{ source.id, source.usebefore } + end + + return source.id +end + +function strategy.open(session, cookie, keep_lock) + return default.load(session, cookie, key(cookie), keep_lock) +end + +function strategy.touch(session, close) + return default.modify(session, "touch", close, key(session)) +end + +function strategy.save(session, close) + if session.present then + local storage = session.storage + if storage.ttl then + storage:ttl(session.encoder.encode(session.id), session.cookie.discard, true) + elseif storage.close then + storage:close(session.encoder.encode(session.id)) + end + + session.id = session:identifier() + end + + return default.modify(session, "save", close, key(session)) +end + +return strategy |