diff options
Diffstat (limited to 'server/resty/openssl/auxiliary/jwk.lua')
-rw-r--r-- | server/resty/openssl/auxiliary/jwk.lua | 261 |
1 files changed, 261 insertions, 0 deletions
diff --git a/server/resty/openssl/auxiliary/jwk.lua b/server/resty/openssl/auxiliary/jwk.lua new file mode 100644 index 0000000..5a505a9 --- /dev/null +++ b/server/resty/openssl/auxiliary/jwk.lua @@ -0,0 +1,261 @@ + +local ffi = require "ffi" +local C = ffi.C + +local cjson = require("cjson.safe") +local b64 = require("ngx.base64") + +local evp_macro = require "resty.openssl.include.evp" +local rsa_lib = require "resty.openssl.rsa" +local ec_lib = require "resty.openssl.ec" +local ecx_lib = require "resty.openssl.ecx" +local bn_lib = require "resty.openssl.bn" +local digest_lib = require "resty.openssl.digest" + +local _M = {} + +local rsa_jwk_params = {"n", "e", "d", "p", "q", "dp", "dq", "qi"} +local rsa_openssl_params = rsa_lib.params + +local function load_jwk_rsa(tbl) + if not tbl["n"] or not tbl["e"] then + return nil, "at least \"n\" and \"e\" parameter is required" + end + + local params = {} + local err + for i, k in ipairs(rsa_jwk_params) do + local v = tbl[k] + if v then + v = b64.decode_base64url(v) + if not v then + return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. tbl[k] + end + + params[rsa_openssl_params[i]], err = bn_lib.from_binary(v) + if err then + return nil, "cannot use parameter \"" .. k .. "\": " .. err + end + end + end + + local key = C.RSA_new() + if key == nil then + return nil, "RSA_new() failed" + end + + local _, err = rsa_lib.set_parameters(key, params) + if err ~= nil then + C.RSA_free(key) + return nil, err + end + + return key +end + +local ec_curves = { + ["P-256"] = C.OBJ_ln2nid("prime256v1"), + ["P-384"] = C.OBJ_ln2nid("secp384r1"), + ["P-521"] = C.OBJ_ln2nid("secp521r1"), +} + +local ec_curves_reverse = {} +for k, v in pairs(ec_curves) do + ec_curves_reverse[v] = k +end + +local ec_jwk_params = {"x", "y", "d"} + +local function load_jwk_ec(tbl) + local curve = tbl['crv'] + if not curve then + return nil, "\"crv\" not defined for EC key" + end + if not tbl["x"] or not tbl["y"] then + return nil, "at least \"x\" and \"y\" parameter is required" + end + local curve_nid = ec_curves[curve] + if not curve_nid then + return nil, "curve \"" .. curve .. "\" is not supported by this library" + elseif curve_nid == 0 then + return nil, "curve \"" .. curve .. "\" is not supported by linked OpenSSL" + end + + local params = {} + local err + for _, k in ipairs(ec_jwk_params) do + local v = tbl[k] + if v then + v = b64.decode_base64url(v) + if not v then + return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. tbl[k] + end + + params[k], err = bn_lib.from_binary(v) + if err then + return nil, "cannot use parameter \"" .. k .. "\": " .. err + end + end + end + + -- map to the name we expect + if params["d"] then + params["private"] = params["d"] + params["d"] = nil + end + params["group"] = curve_nid + + local key = C.EC_KEY_new() + if key == nil then + return nil, "EC_KEY_new() failed" + end + + local _, err = ec_lib.set_parameters(key, params) + if err ~= nil then + C.EC_KEY_free(key) + return nil, err + end + + return key +end + +local function load_jwk_okp(key_type, tbl) + local params = {} + if tbl["d"] then + params.private = b64.decode_base64url(tbl["d"]) + elseif tbl["x"] then + params.public = b64.decode_base64url(tbl["x"]) + else + return nil, "at least \"x\" or \"d\" parameter is required" + end + local key, err = ecx_lib.set_parameters(key_type, nil, params) + if err ~= nil then + return nil, err + end + return key +end + +local ecx_curves_reverse = {} +for k, v in pairs(evp_macro.ecx_curves) do + ecx_curves_reverse[v] = k +end + +function _M.load_jwk(txt) + local tbl, err = cjson.decode(txt) + if err then + return nil, "error decoding JSON from JWK: " .. err + elseif type(tbl) ~= "table" then + return nil, "except input to be decoded as a table, got " .. type(tbl) + end + + local key, key_free, key_type, err + + if tbl["kty"] == "RSA" then + key_type = evp_macro.EVP_PKEY_RSA + if key_type == 0 then + return nil, "the linked OpenSSL library doesn't support RSA key" + end + key, err = load_jwk_rsa(tbl) + key_free = C.RSA_free + elseif tbl["kty"] == "EC" then + key_type = evp_macro.EVP_PKEY_EC + if key_type == 0 then + return nil, "the linked OpenSSL library doesn't support EC key" + end + key, err = load_jwk_ec(tbl) + key_free = C.EC_KEY_free + elseif tbl["kty"] == "OKP" then + local curve = tbl["crv"] + key_type = evp_macro.ecx_curves[curve] + if not key_type then + return nil, "unknown curve \"" .. tostring(curve) + elseif key_type == 0 then + return nil, "the linked OpenSSL library doesn't support \"" .. curve .. "\" key" + end + key, err = load_jwk_okp(key_type, tbl) + if key ~= nil then + return key + end + else + return nil, "not yet supported jwk type \"" .. (tbl["kty"] or "nil") .. "\"" + end + + if err then + return nil, "failed to construct " .. tbl["kty"] .. " key from JWK: " .. err + end + + local ctx = C.EVP_PKEY_new() + if ctx == nil then + key_free(key) + return nil, "EVP_PKEY_new() failed" + end + + local code = C.EVP_PKEY_assign(ctx, key_type, key) + if code ~= 1 then + key_free(key) + C.EVP_PKEY_free(ctx) + return nil, "EVP_PKEY_assign() failed" + end + + return ctx +end + +function _M.dump_jwk(pkey, is_priv) + local jwk + if pkey.key_type == evp_macro.EVP_PKEY_RSA then + local param_keys = { "n" , "e" } + if is_priv then + param_keys = rsa_jwk_params + end + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "RSA", + } + for i, p in ipairs(param_keys) do + local v = params[rsa_openssl_params[i]]:to_binary() + jwk[p] = b64.encode_base64url(v) + end + elseif pkey.key_type == evp_macro.EVP_PKEY_EC then + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "EC", + crv = ec_curves_reverse[params.group], + x = b64.encode_base64url(params.x:to_binary()), + y = b64.encode_base64url(params.x:to_binary()), + } + if is_priv then + jwk.d = b64.encode_base64url(params.private:to_binary()) + end + elseif ecx_curves_reverse[pkey.key_type] then + local params, err = pkey:get_parameters() + if err then + return nil, "jwk.dump_jwk: " .. err + end + jwk = { + kty = "OKP", + crv = ecx_curves_reverse[pkey.key_type], + d = b64.encode_base64url(params.private), + x = b64.encode_base64url(params.public), + } + else + return nil, "jwk.dump_jwk: not implemented for this key type" + end + + local der = pkey:tostring(is_priv and "private" or "public", "DER") + local dgst = digest_lib.new("sha256") + local d, err = dgst:final(der) + if err then + return nil, "jwk.dump_jwk: failed to calculate digest for key" + end + jwk.kid = b64.encode_base64url(d) + + return cjson.encode(jwk) +end + +return _M |