aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/jwt.lua
diff options
context:
space:
mode:
Diffstat (limited to 'server/resty/jwt.lua')
-rw-r--r--server/resty/jwt.lua959
1 files changed, 959 insertions, 0 deletions
diff --git a/server/resty/jwt.lua b/server/resty/jwt.lua
new file mode 100644
index 0000000..accba11
--- /dev/null
+++ b/server/resty/jwt.lua
@@ -0,0 +1,959 @@
+local cjson = require "cjson.safe"
+
+local evp = require "resty.evp"
+local hmac = require "resty.hmac"
+local resty_random = require "resty.random"
+local cipher = require "resty.openssl.cipher"
+
+local _M = { _VERSION = "0.2.3" }
+
+local mt = {
+ __index = _M
+}
+
+local string_rep = string.rep
+local string_format = string.format
+local string_sub = string.sub
+local string_char = string.char
+local table_concat = table.concat
+local ngx_encode_base64 = ngx.encode_base64
+local ngx_decode_base64 = ngx.decode_base64
+local cjson_encode = cjson.encode
+local cjson_decode = cjson.decode
+local tostring = tostring
+local error = error
+local ipairs = ipairs
+local type = type
+local pcall = pcall
+local assert = assert
+local setmetatable = setmetatable
+local pairs = pairs
+
+-- define string constants to avoid string garbage collection
+local str_const = {
+ invalid_jwt= "invalid jwt string",
+ regex_join_msg = "%s.%s",
+ regex_join_delim = "([^%s]+)",
+ regex_split_dot = "%.",
+ regex_jwt_join_str = "%s.%s.%s",
+ raw_underscore = "raw_",
+ dash = "-",
+ empty = "",
+ dotdot = "..",
+ table = "table",
+ plus = "+",
+ equal = "=",
+ underscore = "_",
+ slash = "/",
+ header = "header",
+ typ = "typ",
+ JWT = "JWT",
+ JWE = "JWE",
+ payload = "payload",
+ signature = "signature",
+ encrypted_key = "encrypted_key",
+ alg = "alg",
+ enc = "enc",
+ kid = "kid",
+ exp = "exp",
+ nbf = "nbf",
+ iss = "iss",
+ full_obj = "__jwt",
+ x5c = "x5c",
+ x5u = 'x5u',
+ HS256 = "HS256",
+ HS512 = "HS512",
+ RS256 = "RS256",
+ ES256 = "ES256",
+ ES512 = "ES512",
+ RS512 = "RS512",
+ A128CBC_HS256 = "A128CBC-HS256",
+ A128CBC_HS256_CIPHER_MODE = "aes-128-cbc",
+ A256CBC_HS512 = "A256CBC-HS512",
+ A256CBC_HS512_CIPHER_MODE = "aes-256-cbc",
+ A256GCM = "A256GCM",
+ A256GCM_CIPHER_MODE = "aes-256-gcm",
+ RSA_OAEP_256 = "RSA-OAEP-256",
+ DIR = "dir",
+ reason = "reason",
+ verified = "verified",
+ number = "number",
+ string = "string",
+ funct = "function",
+ boolean = "boolean",
+ valid = "valid",
+ valid_issuers = "valid_issuers",
+ lifetime_grace_period = "lifetime_grace_period",
+ require_nbf_claim = "require_nbf_claim",
+ require_exp_claim = "require_exp_claim",
+ internal_error = "internal error",
+ everything_awesome = "everything is awesome~ :p"
+}
+
+-- @function split string
+local function split_string(str, delim)
+ local result = {}
+ local sep = string_format(str_const.regex_join_delim, delim)
+ for m in str:gmatch(sep) do
+ result[#result+1]=m
+ end
+ return result
+end
+
+-- @function is nil or boolean
+-- @return true if param is nil or true or false; false otherwise
+local function is_nil_or_boolean(arg_value)
+ if arg_value == nil then
+ return true
+ end
+
+ if type(arg_value) ~= str_const.boolean then
+ return false
+ end
+
+ return true
+end
+
+--@function get the raw part
+--@param part_name
+--@param jwt_obj
+local function get_raw_part(part_name, jwt_obj)
+ local raw_part = jwt_obj[str_const.raw_underscore .. part_name]
+ if raw_part == nil then
+ local part = jwt_obj[part_name]
+ if part == nil then
+ error({reason="missing part " .. part_name})
+ end
+ raw_part = _M:jwt_encode(part)
+ end
+ return raw_part
+end
+
+
+--@function decrypt payload
+--@param secret_key to decrypt the payload
+--@param encrypted payload
+--@param encryption algorithm
+--@param iv which was generated while encrypting the payload
+--@param aad additional authenticated data (used when gcm mode is used)
+--@param auth_tag authenticated tag (used when gcm mode is used)
+--@return decrypted payloaf
+local function decrypt_payload(secret_key, encrypted_payload, enc, iv_in, aad, auth_tag )
+ local decrypted_payload, err
+ if enc == str_const.A128CBC_HS256 then
+ local aes_128_cbs_cipher = assert(cipher.new(str_const.A128CBC_HS256_CIPHER_MODE))
+ decrypted_payload, err= aes_128_cbs_cipher:decrypt(secret_key, iv_in, encrypted_payload)
+ elseif enc == str_const.A256CBC_HS512 then
+ local aes_256_cbs_cipher = assert(cipher.new(str_const.A256CBC_HS512_CIPHER_MODE))
+ decrypted_payload, err = aes_256_cbs_cipher:decrypt(secret_key, iv_in, encrypted_payload)
+ elseif enc == str_const.A256GCM then
+ local aes_256_gcm_cipher = assert(cipher.new(str_const.A256GCM_CIPHER_MODE))
+ decrypted_payload, err = aes_256_gcm_cipher:decrypt(secret_key, iv_in, encrypted_payload, false, aad, auth_tag)
+ else
+ return nil, "unsupported enc: " .. enc
+ end
+ if not decrypted_payload or err then
+ return nil, err
+ end
+ return decrypted_payload
+end
+
+-- @function encrypt payload using given secret
+-- @param secret_key secret key to encrypt
+-- @param message data to be encrypted. It could be lua table or string
+-- @param enc algorithm to use for encryption
+-- @param aad additional authenticated data (used when gcm mode is used)
+local function encrypt_payload(secret_key, message, enc, aad )
+
+ if enc == str_const.A128CBC_HS256 then
+ local iv_rand = resty_random.bytes(16,true)
+ local aes_128_cbs_cipher = assert(cipher.new(str_const.A128CBC_HS256_CIPHER_MODE))
+ local encrypted = aes_128_cbs_cipher:encrypt(secret_key, iv_rand, message)
+ return encrypted, iv_rand
+
+ elseif enc == str_const.A256CBC_HS512 then
+ local iv_rand = resty_random.bytes(16,true)
+ local aes_256_cbs_cipher = assert(cipher.new(str_const.A256CBC_HS512_CIPHER_MODE))
+ local encrypted = aes_256_cbs_cipher:encrypt(secret_key, iv_rand, message)
+ return encrypted, iv_rand
+
+ elseif enc == str_const.A256GCM then
+ local iv_rand = resty_random.bytes(12,true) -- 96 bit IV is recommended for efficiency
+ local aes_256_gcm_cipher = assert(cipher.new(str_const.A256GCM_CIPHER_MODE))
+ local encrypted = aes_256_gcm_cipher:encrypt(secret_key, iv_rand, message, false, aad)
+ local auth_tag = assert(aes_256_gcm_cipher:get_aead_tag())
+ return encrypted, iv_rand, auth_tag
+
+ else
+ return nil, nil , nil, "unsupported enc: " .. enc
+ end
+end
+
+--@function hmac_digest : generate hmac digest based on key for input message
+--@param mac_key
+--@param input message
+--@return hmac digest
+local function hmac_digest(enc, mac_key, message)
+ if enc == str_const.A128CBC_HS256 then
+ return hmac:new(mac_key, hmac.ALGOS.SHA256):final(message)
+ elseif enc == str_const.A256CBC_HS512 then
+ return hmac:new(mac_key, hmac.ALGOS.SHA512):final(message)
+ else
+ error({reason="unsupported enc: " .. enc})
+ end
+end
+
+--@function dervice keys: it generates key if null based on encryption algorithm
+--@param encryption type
+--@param secret key
+--@return secret key, mac key and encryption key
+local function derive_keys(enc, secret_key)
+ local mac_key_len, enc_key_len = 16, 16
+
+ if enc == str_const.A256GCM then
+ mac_key_len, enc_key_len = 0, 32 -- we need 256 bit key
+ elseif enc == str_const.A128CBC_HS256 then
+ mac_key_len, enc_key_len = 16, 16
+ elseif enc == str_const.A256CBC_HS512 then
+ mac_key_len, enc_key_len = 32, 32
+ else
+ error({reason="unsupported payload encryption algorithm :" .. enc})
+ end
+
+ local secret_key_len = mac_key_len + enc_key_len
+
+ if not secret_key then
+ secret_key = resty_random.bytes(secret_key_len, true)
+ end
+
+ if #secret_key ~= secret_key_len then
+ error({reason="invalid pre-shared key"})
+ end
+
+ local mac_key = string_sub(secret_key, 1, mac_key_len)
+ local enc_key = string_sub(secret_key, mac_key_len + 1)
+ return secret_key, mac_key, enc_key
+end
+
+local function get_payload_encoder(self)
+ return self.payload_encoder or cjson_encode
+end
+
+local function get_payload_decoder(self)
+ return self.payload_decoder or cjson_decode
+end
+
+--@function parse_jwe
+--@param pre-shared key
+--@encoded-header
+local function parse_jwe(self, preshared_key, encoded_header, encoded_encrypted_key, encoded_iv, encoded_cipher_text, encoded_auth_tag)
+
+
+ local header = _M:jwt_decode(encoded_header, true)
+ if not header then
+ error({reason="invalid header: " .. encoded_header})
+ end
+
+ local alg = header.alg
+ if alg ~= str_const.DIR and alg ~= str_const.RSA_OAEP_256 then
+ error({reason="invalid algorithm: " .. alg})
+ end
+
+ local key, enc_key
+ if alg == str_const.DIR then
+ if not preshared_key then
+ error({reason="preshared key must not be null"})
+ end
+ key, _, enc_key = derive_keys(header.enc, preshared_key)
+ elseif alg == str_const.RSA_OAEP_256 then
+ if not preshared_key then
+ error({reason="rsa private key must not be null"})
+ end
+ local rsa_decryptor, err = evp.RSADecryptor:new(preshared_key, nil, evp.CONST.RSA_PKCS1_OAEP_PADDING, evp.CONST.SHA256_DIGEST)
+ if err then
+ error({reason="failed to create rsa object: ".. err})
+ end
+ local secret_key, err = rsa_decryptor:decrypt(_M:jwt_decode(encoded_encrypted_key))
+ if err or not secret_key then
+ error({reason="failed to decrypt key: " .. err})
+ end
+ key, _, enc_key = derive_keys(header.enc, secret_key)
+ end
+
+ local cipher_text = _M:jwt_decode(encoded_cipher_text)
+ local iv = _M:jwt_decode(encoded_iv)
+ local signature_or_tag = _M:jwt_decode(encoded_auth_tag)
+ local basic_jwe = {
+ internal = {
+ encoded_header = encoded_header,
+ cipher_text = cipher_text,
+ key = key,
+ iv = iv
+ },
+ header = header,
+ signature = signature_or_tag
+ }
+
+ local payload, err = decrypt_payload(enc_key, cipher_text, header.enc, iv, encoded_header, signature_or_tag)
+ if err then
+ error({reason="failed to decrypt payload: " .. err})
+
+ else
+ basic_jwe.payload = get_payload_decoder(self)(payload)
+ basic_jwe.internal.json_payload=payload
+ end
+ return basic_jwe
+end
+
+-- @function parse_jwt
+-- @param encoded header
+-- @param encoded
+-- @param signature
+-- @return jwt table
+local function parse_jwt(encoded_header, encoded_payload, signature)
+ local header = _M:jwt_decode(encoded_header, true)
+ if not header then
+ error({reason="invalid header: " .. encoded_header})
+ end
+
+ local payload = _M:jwt_decode(encoded_payload, true)
+ if not payload then
+ error({reason="invalid payload: " .. encoded_payload})
+ end
+
+ local basic_jwt = {
+ raw_header=encoded_header,
+ raw_payload=encoded_payload,
+ header=header,
+ payload=payload,
+ signature=signature
+ }
+ return basic_jwt
+
+end
+
+-- @function parse token - this can be JWE or JWT token
+-- @param token string
+-- @return jwt/jwe tables
+local function parse(self, secret, token_str)
+ local tokens = split_string(token_str, str_const.regex_split_dot)
+ local num_tokens = #tokens
+ if num_tokens == 3 then
+ return parse_jwt(tokens[1], tokens[2], tokens[3])
+ elseif num_tokens == 4 then
+ return parse_jwe(self, secret, tokens[1], nil, tokens[2], tokens[3], tokens[4])
+ elseif num_tokens == 5 then
+ return parse_jwe(self, secret, tokens[1], tokens[2], tokens[3], tokens[4], tokens[5])
+ else
+ error({reason=str_const.invalid_jwt})
+ end
+end
+
+--@function jwt encode : it converts into base64 encoded string. if input is a table, it convets into
+-- json before converting to base64 string
+--@param payloaf
+--@return base64 encoded payloaf
+function _M.jwt_encode(self, ori, is_payload)
+ if type(ori) == str_const.table then
+ ori = is_payload and get_payload_encoder(self)(ori) or cjson_encode(ori)
+ end
+ local res = ngx_encode_base64(ori):gsub(str_const.plus, str_const.dash):gsub(str_const.slash, str_const.underscore):gsub(str_const.equal, str_const.empty)
+ return res
+end
+
+
+
+--@function jwt decode : decode bas64 encoded string
+function _M.jwt_decode(self, b64_str, json_decode, is_payload)
+ b64_str = b64_str:gsub(str_const.dash, str_const.plus):gsub(str_const.underscore, str_const.slash)
+
+ local reminder = #b64_str % 4
+ if reminder > 0 then
+ b64_str = b64_str .. string_rep(str_const.equal, 4 - reminder)
+ end
+ local data = ngx_decode_base64(b64_str)
+ if not data then
+ return nil
+ end
+ if json_decode then
+ data = is_payload and get_payload_decoder(self)(data) or cjson_decode(data)
+ end
+ return data
+end
+
+--- Initialize the trusted certs
+-- During RS256 verify, we'll make sure the
+-- cert was signed by one of these
+function _M.set_trusted_certs_file(self, filename)
+ self.trusted_certs_file = filename
+end
+_M.trusted_certs_file = nil
+
+--- Set a whitelist of allowed algorithms
+-- E.g., jwt:set_alg_whitelist({RS256=1,HS256=1})
+--
+-- @param algorithms - A table with keys for the supported algorithms
+-- If the table is non-nil, during
+-- verify, the alg must be in the table
+function _M.set_alg_whitelist(self, algorithms)
+ self.alg_whitelist = algorithms
+end
+
+_M.alg_whitelist = nil
+
+
+--- Returns the list of default validations that will be
+--- applied upon the verification of a jwt.
+function _M.get_default_validation_options(self, jwt_obj)
+ return {
+ [str_const.require_exp_claim]=jwt_obj[str_const.payload].exp ~= nil,
+ [str_const.require_nbf_claim]=jwt_obj[str_const.payload].nbf ~= nil
+ }
+end
+
+--- Set a function used to retrieve the content of x5u urls
+--
+-- @param retriever_function - A pointer to a function. This function should be
+-- defined to accept three string parameters. First one
+-- will be the value of the 'x5u' attribute. Second
+-- one will be the value of the 'iss' attribute, would
+-- it be defined in the jwt. Third one will be the value
+-- of the 'iss' attribute, would it be defined in the jwt.
+-- This function should return the matching certificate.
+function _M.set_x5u_content_retriever(self, retriever_function)
+ if type(retriever_function) ~= str_const.funct then
+ error("'retriever_function' is expected to be a function", 0)
+ end
+ self.x5u_content_retriever = retriever_function
+end
+
+_M.x5u_content_retriever = nil
+
+-- https://tools.ietf.org/html/rfc7516#appendix-B.3
+-- TODO: do it in lua way
+local function binlen(s)
+ if type(s) ~= 'string' then return end
+
+ local len = 8 * #s
+
+ return string_char(len / 0x0100000000000000 % 0x100)
+ .. string_char(len / 0x0001000000000000 % 0x100)
+ .. string_char(len / 0x0000010000000000 % 0x100)
+ .. string_char(len / 0x0000000100000000 % 0x100)
+ .. string_char(len / 0x0000000001000000 % 0x100)
+ .. string_char(len / 0x0000000000010000 % 0x100)
+ .. string_char(len / 0x0000000000000100 % 0x100)
+ .. string_char(len / 0x0000000000000001 % 0x100)
+end
+
+--@function sign jwe payload
+--@param secret key : if used pre-shared or RSA key
+--@param jwe payload
+--@return jwe token
+local function sign_jwe(self, secret_key, jwt_obj)
+ local header = jwt_obj.header
+ local enc = header.enc
+ local alg = header.alg
+
+ -- remove type
+ if header.typ then
+ header.typ = nil
+ end
+
+ -- TODO: implement logic for creating enc key and mac key and then encrypt key
+ local key, encrypted_key, mac_key, enc_key
+ local encoded_header = _M:jwt_encode(header)
+ local payload_to_encrypt = get_payload_encoder(self)(jwt_obj.payload)
+ if alg == str_const.DIR then
+ _, mac_key, enc_key = derive_keys(enc, secret_key)
+ encrypted_key = ""
+ elseif alg == str_const.RSA_OAEP_256 then
+ local cert, err
+ if secret_key:find("CERTIFICATE") then
+ cert, err = evp.Cert:new(secret_key)
+ elseif secret_key:find("PUBLIC KEY") then
+ cert, err = evp.PublicKey:new(secret_key)
+ end
+ if not cert then
+ error({reason="Decode secret is not a valid cert/public key: " .. (err and err or secret_key)})
+ end
+ local rsa_encryptor = evp.RSAEncryptor:new(cert, evp.CONST.RSA_PKCS1_OAEP_PADDING, evp.CONST.SHA256_DIGEST)
+ if err then
+ error("failed to create rsa object for encryption ".. err)
+ end
+ key, mac_key, enc_key = derive_keys(enc)
+ encrypted_key, err = rsa_encryptor:encrypt(key)
+ if err or not encrypted_key then
+ error({reason="failed to encrypt key " .. (err or "")})
+ end
+ else
+ error({reason="unsupported alg: " .. alg})
+ end
+
+ local cipher_text, iv, auth_tag, err = encrypt_payload(enc_key, payload_to_encrypt, enc, encoded_header)
+ if err then
+ error({reason="error while encrypting payload. Error: " .. err})
+ end
+
+ if not auth_tag then
+ local encoded_header_length = binlen(encoded_header)
+ local mac_input = table_concat({encoded_header , iv, cipher_text , encoded_header_length})
+ local mac = hmac_digest(enc, mac_key, mac_input)
+ auth_tag = string_sub(mac, 1, #mac/2)
+ end
+
+ local jwe_table = {encoded_header, _M:jwt_encode(encrypted_key), _M:jwt_encode(iv),
+ _M:jwt_encode(cipher_text), _M:jwt_encode(auth_tag)}
+ return table_concat(jwe_table, ".", 1, 5)
+end
+
+--@function get_secret_str : returns the secret if it is a string, or the result of a function
+--@param either the string secret or a function that takes a string parameter and returns a string or nil
+--@param jwt payload
+--@return the secret as a string or as a function
+local function get_secret_str(secret_or_function, jwt_obj)
+ if type(secret_or_function) == str_const.funct then
+ -- Only use with hmac algorithms
+ local alg = jwt_obj[str_const.header][str_const.alg]
+ if alg ~= str_const.HS256 and alg ~= str_const.HS512 then
+ error({reason="secret function can only be used with hmac alg: " .. alg})
+ end
+
+ -- Pull out the kid value from the header
+ local kid_val = jwt_obj[str_const.header][str_const.kid]
+ if kid_val == nil then
+ error({reason="secret function specified without kid in header"})
+ end
+
+ -- Call the function
+ return secret_or_function(kid_val) or error({reason="function returned nil for kid: " .. kid_val})
+ elseif type(secret_or_function) == str_const.string then
+ -- Just return the string
+ return secret_or_function
+ else
+ -- Throw an error
+ error({reason="invalid secret type (must be string or function)"})
+ end
+end
+
+--@function sign : create a jwt/jwe signature from jwt_object
+--@param secret key
+--@param jwt/jwe payload
+function _M.sign(self, secret_key, jwt_obj)
+ -- header typ check
+ local typ = jwt_obj[str_const.header][str_const.typ]
+ -- Optional header typ check [See http://tools.ietf.org/html/draft-ietf-oauth-json-web-token-25#section-5.1]
+ if typ ~= nil then
+ if typ ~= str_const.JWT and typ ~= str_const.JWE then
+ error({reason="invalid typ: " .. typ})
+ end
+ end
+
+ if typ == str_const.JWE or jwt_obj.header.enc then
+ return sign_jwe(self, secret_key, jwt_obj)
+ end
+ -- header alg check
+ local raw_header = get_raw_part(str_const.header, jwt_obj)
+ local raw_payload = get_raw_part(str_const.payload, jwt_obj)
+ local message = string_format(str_const.regex_join_msg, raw_header, raw_payload)
+ local alg = jwt_obj[str_const.header][str_const.alg]
+ local signature = ""
+ if alg == str_const.HS256 then
+ local secret_str = get_secret_str(secret_key, jwt_obj)
+ signature = hmac:new(secret_str, hmac.ALGOS.SHA256):final(message)
+ elseif alg == str_const.HS512 then
+ local secret_str = get_secret_str(secret_key, jwt_obj)
+ signature = hmac:new(secret_str, hmac.ALGOS.SHA512):final(message)
+ elseif alg == str_const.RS256 or alg == str_const.RS512 then
+ local signer, err = evp.RSASigner:new(secret_key)
+ if not signer then
+ error({reason="signer error: " .. err})
+ end
+ if alg == str_const.RS256 then
+ signature = signer:sign(message, evp.CONST.SHA256_DIGEST)
+ elseif alg == str_const.RS512 then
+ signature = signer:sign(message, evp.CONST.SHA512_DIGEST)
+ end
+ elseif alg == str_const.ES256 or alg == str_const.ES512 then
+ local signer, err = evp.ECSigner:new(secret_key)
+ if not signer then
+ error({reason="signer error: " .. err})
+ end
+ -- OpenSSL will generate a DER encoded signature that needs to be converted
+ local der_signature = ""
+ if alg == str_const.ES256 then
+ der_signature = signer:sign(message, evp.CONST.SHA256_DIGEST)
+ elseif alg == str_const.ES512 then
+ der_signature = signer:sign(message, evp.CONST.SHA512_DIGEST)
+ end
+ -- Perform DER to RAW signature conversion
+ signature, err = signer:get_raw_sig(der_signature)
+ if not signature then
+ error({reason="signature error: " .. err})
+ end
+ else
+ error({reason="unsupported alg: " .. alg})
+ end
+ -- return full jwt string
+ return string_format(str_const.regex_join_msg, message , _M:jwt_encode(signature))
+
+end
+
+--@function load jwt
+--@param jwt string token
+--@param secret
+function _M.load_jwt(self, jwt_str, secret)
+ local success, ret = pcall(parse, self, secret, jwt_str)
+ if not success then
+ return {
+ valid=false,
+ verified=false,
+ reason=ret[str_const.reason] or str_const.invalid_jwt
+ }
+ end
+
+ local jwt_obj = ret
+ jwt_obj[str_const.verified] = false
+ jwt_obj[str_const.valid] = true
+ return jwt_obj
+end
+
+--@function verify jwe object
+--@param jwt object
+--@return jwt object with reason whether verified or not
+local function verify_jwe_obj(jwt_obj)
+
+ if jwt_obj[str_const.header][str_const.enc] ~= str_const.A256GCM then -- tag gets authenticated during decryption
+ local _, mac_key, _ = derive_keys(jwt_obj.header.enc, jwt_obj.internal.key)
+ local encoded_header = jwt_obj.internal.encoded_header
+
+ local encoded_header_length = binlen(encoded_header)
+ local mac_input = table_concat({encoded_header , jwt_obj.internal.iv, jwt_obj.internal.cipher_text,
+ encoded_header_length})
+ local mac = hmac_digest(jwt_obj.header.enc, mac_key, mac_input)
+ local auth_tag = string_sub(mac, 1, #mac/2)
+
+ if auth_tag ~= jwt_obj.signature then
+ jwt_obj[str_const.reason] = "signature mismatch: " ..
+ tostring(jwt_obj[str_const.signature])
+ end
+ end
+
+ jwt_obj.internal = nil
+ jwt_obj.signature = nil
+
+ if not jwt_obj[str_const.reason] then
+ jwt_obj[str_const.verified] = true
+ jwt_obj[str_const.reason] = str_const.everything_awesome
+ end
+
+ return jwt_obj
+end
+
+--@function extract certificate
+--@param jwt object
+--@return decoded certificate
+local function extract_certificate(jwt_obj, x5u_content_retriever)
+ local x5c = jwt_obj[str_const.header][str_const.x5c]
+ if x5c ~= nil and x5c[1] ~= nil then
+ -- TODO Might want to add support for intermediaries that we
+ -- don't have in our trusted chain (items 2... if present)
+
+ local cert_str = ngx_decode_base64(x5c[1])
+ if not cert_str then
+ jwt_obj[str_const.reason] = "Malformed x5c header"
+ end
+
+ return cert_str
+ end
+
+ local x5u = jwt_obj[str_const.header][str_const.x5u]
+ if x5u ~= nil then
+ -- TODO Ensure the url starts with https://
+ -- cf. https://tools.ietf.org/html/rfc7517#section-4.6
+
+ if x5u_content_retriever == nil then
+ jwt_obj[str_const.reason] = "No function has been provided to retrieve the content pointed at by the 'x5u'."
+ return nil
+ end
+
+ -- TODO Maybe validate the url against an optional list whitelisted url prefixes?
+ -- cf. https://news.ycombinator.com/item?id=9302394
+
+ local iss = jwt_obj[str_const.payload][str_const.iss]
+ local kid = jwt_obj[str_const.header][str_const.kid]
+ local success, ret = pcall(x5u_content_retriever, x5u, iss, kid)
+
+ if not success then
+ jwt_obj[str_const.reason] = "An error occured while invoking the x5u_content_retriever function."
+ return nil
+ end
+
+ return ret
+ end
+
+ -- TODO When both x5c and x5u are defined, the implementation should
+ -- ensure their content match
+ -- cf. https://tools.ietf.org/html/rfc7517#section-4.6
+
+ jwt_obj[str_const.reason] = "Unsupported RS256 key model"
+ return nil
+ -- TODO - Implement jwk and kid based models...
+end
+
+local function get_claim_spec_from_legacy_options(self, options)
+ local claim_spec = { }
+ local jwt_validators = require "resty.jwt-validators"
+
+ if options[str_const.valid_issuers] ~= nil then
+ claim_spec[str_const.iss] = jwt_validators.equals_any_of(options[str_const.valid_issuers])
+ end
+
+ if options[str_const.lifetime_grace_period] ~= nil then
+ jwt_validators.set_system_leeway(options[str_const.lifetime_grace_period] or 0)
+
+ -- If we have a leeway set, then either an NBF or an EXP should also exist requireds are added below
+ if options[str_const.require_nbf_claim] ~= true and options[str_const.require_exp_claim] ~= true then
+ claim_spec[str_const.full_obj] = jwt_validators.require_one_of({ str_const.nbf, str_const.exp })
+ end
+ end
+
+ if not is_nil_or_boolean(options[str_const.require_nbf_claim]) then
+ error(string.format("'%s' validation option is expected to be a boolean.", str_const.require_nbf_claim), 0)
+ end
+
+ if not is_nil_or_boolean(options[str_const.require_exp_claim]) then
+ error(string.format("'%s' validation option is expected to be a boolean.", str_const.require_exp_claim), 0)
+ end
+
+ if options[str_const.lifetime_grace_period] ~= nil or options[str_const.require_nbf_claim] ~= nil or options[str_const.require_exp_claim] ~= nil then
+ if options[str_const.require_nbf_claim] == true then
+ claim_spec[str_const.nbf] = jwt_validators.is_not_before()
+ else
+ claim_spec[str_const.nbf] = jwt_validators.opt_is_not_before()
+ end
+
+ if options[str_const.require_exp_claim] == true then
+ claim_spec[str_const.exp] = jwt_validators.is_not_expired()
+ else
+ claim_spec[str_const.exp] = jwt_validators.opt_is_not_expired()
+ end
+ end
+
+ return claim_spec
+end
+
+local function is_legacy_validation_options(options)
+
+ -- Validation options MUST be a table
+ if type(options) ~= str_const.table then
+ return false
+ end
+
+ -- Validation options MUST have at least one of these, and must ONLY have these
+ local legacy_options = { }
+ legacy_options[str_const.valid_issuers]=1
+ legacy_options[str_const.lifetime_grace_period]=1
+ legacy_options[str_const.require_nbf_claim]=1
+ legacy_options[str_const.require_exp_claim]=1
+
+ local is_legacy = false
+ for k in pairs(options) do
+ if legacy_options[k] ~= nil then
+ is_legacy = true
+ else
+ return false
+ end
+ end
+ return is_legacy
+end
+
+-- Validates the claims for the given (parsed) object
+local function validate_claims(self, jwt_obj, ...)
+ local claim_specs = {...}
+ if #claim_specs == 0 then
+ table.insert(claim_specs, _M:get_default_validation_options(jwt_obj))
+ end
+
+ if jwt_obj[str_const.reason] ~= nil then
+ return false
+ end
+
+ -- Encode the current jwt_obj and use it when calling the individual validation functions
+ local jwt_json = cjson_encode(jwt_obj)
+
+ -- Validate all our specs
+ for _, claim_spec in ipairs(claim_specs) do
+ if is_legacy_validation_options(claim_spec) then
+ claim_spec = get_claim_spec_from_legacy_options(self, claim_spec)
+ end
+ for claim, fx in pairs(claim_spec) do
+ if type(fx) ~= str_const.funct then
+ error("Claim spec value must be a function - see jwt-validators.lua for helper functions", 0)
+ end
+
+ local val = claim == str_const.full_obj and cjson_decode(jwt_json) or jwt_obj.payload[claim]
+ local success, ret = pcall(fx, val, claim, jwt_json)
+ if not success then
+ jwt_obj[str_const.reason] = ret.reason or string.gsub(ret, "^.-:%d-: ", "")
+ return false
+ elseif ret == false then
+ jwt_obj[str_const.reason] = string.format("Claim '%s' ('%s') returned failure", claim, val)
+ return false
+ end
+ end
+ end
+
+ -- Everything was good
+ return true
+end
+
+--@function verify jwt object
+--@param secret
+--@param jwt_object
+--@leeway
+--@return verified jwt payload or jwt object with error code
+function _M.verify_jwt_obj(self, secret, jwt_obj, ...)
+ if not jwt_obj.valid then
+ return jwt_obj
+ end
+
+ -- validate any claims that have been passed in
+ if not validate_claims(self, jwt_obj, ...) then
+ return jwt_obj
+ end
+
+ -- if jwe, invoked verify jwe
+ if jwt_obj[str_const.header][str_const.enc] then
+ return verify_jwe_obj(jwt_obj)
+ end
+
+ local alg = jwt_obj[str_const.header][str_const.alg]
+
+ local jwt_str = string_format(str_const.regex_jwt_join_str, jwt_obj.raw_header , jwt_obj.raw_payload , jwt_obj.signature)
+
+ if self.alg_whitelist ~= nil then
+ if self.alg_whitelist[alg] == nil then
+ return {verified=false, reason="whitelist unsupported alg: " .. alg}
+ end
+ end
+
+ if alg == str_const.HS256 or alg == str_const.HS512 then
+ local success, ret = pcall(_M.sign, self, secret, jwt_obj)
+ if not success then
+ -- syntax check
+ jwt_obj[str_const.reason] = ret[str_const.reason] or str_const.internal_error
+ elseif jwt_str ~= ret then
+ -- signature check
+ jwt_obj[str_const.reason] = "signature mismatch: " .. jwt_obj[str_const.signature]
+ end
+ elseif alg == str_const.RS256 or alg == str_const.RS512 or alg == str_const.ES256 or alg == str_const.ES512 then
+ local cert, err
+ if self.trusted_certs_file ~= nil then
+ local cert_str = extract_certificate(jwt_obj, self.x5u_content_retriever)
+ if not cert_str then
+ return jwt_obj
+ end
+ cert, err = evp.Cert:new(cert_str)
+ if not cert then
+ jwt_obj[str_const.reason] = "Unable to extract signing cert from JWT: " .. err
+ return jwt_obj
+ end
+ -- Try validating against trusted CA's, then a cert passed as secret
+ local trusted = cert:verify_trust(self.trusted_certs_file)
+ if not trusted then
+ jwt_obj[str_const.reason] = "Cert used to sign the JWT isn't trusted: " .. err
+ return jwt_obj
+ end
+ elseif secret ~= nil then
+ if secret:find("CERTIFICATE") then
+ cert, err = evp.Cert:new(secret)
+ elseif secret:find("PUBLIC KEY") then
+ cert, err = evp.PublicKey:new(secret)
+ end
+ if not cert then
+ jwt_obj[str_const.reason] = "Decode secret is not a valid cert/public key"
+ return jwt_obj
+ end
+ else
+ jwt_obj[str_const.reason] = "No trusted certs loaded"
+ return jwt_obj
+ end
+ local verifier = ''
+ if alg == str_const.RS256 or alg == str_const.RS512 then
+ verifier = evp.RSAVerifier:new(cert)
+ elseif alg == str_const.ES256 or alg == str_const.ES512 then
+ verifier = evp.ECVerifier:new(cert)
+ end
+ if not verifier then
+ -- Internal error case, should not happen...
+ jwt_obj[str_const.reason] = "Failed to build verifier " .. err
+ return jwt_obj
+ end
+
+ -- assemble jwt parts
+ local raw_header = get_raw_part(str_const.header, jwt_obj)
+ local raw_payload = get_raw_part(str_const.payload, jwt_obj)
+
+ local message =string_format(str_const.regex_join_msg, raw_header , raw_payload)
+ local sig = _M:jwt_decode(jwt_obj[str_const.signature], false)
+
+ if not sig then
+ jwt_obj[str_const.reason] = "Wrongly encoded signature"
+ return jwt_obj
+ end
+
+ local verified = false
+ err = "verify error: reason unknown"
+
+ if alg == str_const.RS256 or alg == str_const.ES256 then
+ verified, err = verifier:verify(message, sig, evp.CONST.SHA256_DIGEST)
+ elseif alg == str_const.RS512 or alg == str_const.ES512 then
+ verified, err = verifier:verify(message, sig, evp.CONST.SHA512_DIGEST)
+ end
+ if not verified then
+ jwt_obj[str_const.reason] = err
+ end
+ else
+ jwt_obj[str_const.reason] = "Unsupported algorithm " .. alg
+ end
+
+ if not jwt_obj[str_const.reason] then
+ jwt_obj[str_const.verified] = true
+ jwt_obj[str_const.reason] = str_const.everything_awesome
+ end
+ return jwt_obj
+
+end
+
+
+function _M.verify(self, secret, jwt_str, ...)
+ local jwt_obj = _M.load_jwt(self, jwt_str, secret)
+ if not jwt_obj.valid then
+ return {verified=false, reason=jwt_obj[str_const.reason]}
+ end
+ return _M.verify_jwt_obj(self, secret, jwt_obj, ...)
+
+end
+
+function _M.set_payload_encoder(self, encoder)
+ if type(encoder) ~= "function" then
+ error({reason="payload encoder must be function"})
+ end
+ self.payload_encoder = encoder
+end
+
+
+function _M.set_payload_decoder(self, decoder)
+ if type(decoder) ~= "function" then
+ error({reason="payload decoder must be function"})
+ end
+ self.payload_decoder= decoder
+end
+
+
+function _M.new()
+ return setmetatable({}, mt)
+end
+
+return _M