aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/openssl/auxiliary/jwk.lua
blob: 5a505a933b7d2320ed622eff3d8289d4836e3444 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
local ffi = require "ffi"
local C = ffi.C

local cjson = require("cjson.safe")
local b64 = require("ngx.base64")

local evp_macro = require "resty.openssl.include.evp"
local rsa_lib = require "resty.openssl.rsa"
local ec_lib = require "resty.openssl.ec"
local ecx_lib = require "resty.openssl.ecx"
local bn_lib = require "resty.openssl.bn"
local digest_lib = require "resty.openssl.digest"

local _M = {}

-- JWK member names for RSA keys (RFC 7518 section 6.3). load_jwk_rsa and
-- dump_jwk index this list and rsa_lib.params with the same position i, so
-- both lists must name the same RSA components in the same order.
local rsa_jwk_params = {"n", "e", "d", "p", "q", "dp", "dq", "qi"}
local rsa_openssl_params = rsa_lib.params

--- Build an OpenSSL RSA key from a decoded JWK table.
-- Every member of rsa_jwk_params present in `tbl` is base64url-decoded,
-- turned into a BIGNUM and passed to rsa_lib.set_parameters.
-- @tparam table tbl decoded JWK object; "n" and "e" are mandatory
-- @return RSA* cdata on success (the caller must free it or assign it to an
--   EVP_PKEY), or nil and an error string
local function load_jwk_rsa(tbl)
  if not tbl["n"] or not tbl["e"] then
    return nil, "at least \"n\" and \"e\" parameter is required"
  end

  local params = {}
  local err
  for i, k in ipairs(rsa_jwk_params) do
    local raw = tbl[k]
    if raw then
      -- JSON members may be numbers, tables or cjson.null; decode_base64url
      -- expects a string argument, so reject other types here to keep the
      -- nil, err contract instead of raising a Lua error.
      if type(raw) ~= "string" then
        return nil, "parameter \"" .. k .. "\" must be a string, got " .. type(raw)
      end

      local v = b64.decode_base64url(raw)
      if not v then
        return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. raw
      end

      params[rsa_openssl_params[i]], err = bn_lib.from_binary(v)
      if err then
        return nil, "cannot use parameter \"" .. k .. "\": " .. err
      end
    end
  end

  local key = C.RSA_new()
  if key == nil then
    return nil, "RSA_new() failed"
  end

  local _, set_err = rsa_lib.set_parameters(key, params)
  if set_err ~= nil then
    -- the key was never handed to anyone, release it here
    C.RSA_free(key)
    return nil, set_err
  end

  return key
end

-- JWK "crv" names for EC keys mapped to OpenSSL NIDs. OBJ_ln2nid yields 0 for
-- curves the linked OpenSSL does not know; load_jwk_ec reports that case.
local ec_curves = {}
do
  local openssl_long_names = {
    ["P-256"] = "prime256v1",
    ["P-384"] = "secp384r1",
    ["P-521"] = "secp521r1",
  }
  for jwk_name, long_name in pairs(openssl_long_names) do
    ec_curves[jwk_name] = C.OBJ_ln2nid(long_name)
  end
end

-- Reverse lookup (NID -> JWK "crv" name), used when dumping EC keys.
local ec_curves_reverse = {}
for jwk_name, nid in pairs(ec_curves) do
  ec_curves_reverse[nid] = jwk_name
end

-- JWK member names carried by EC keys (RFC 7518 section 6.2).
local ec_jwk_params = {"x", "y", "d"}

--- Build an OpenSSL EC_KEY from a decoded JWK table.
-- @tparam table tbl decoded JWK object; "crv", "x" and "y" are mandatory,
--   "d" (private scalar) is optional
-- @return EC_KEY* cdata on success (the caller must free it or assign it to
--   an EVP_PKEY), or nil and an error string
local function load_jwk_ec(tbl)
  local curve = tbl["crv"]
  if not curve then
    return nil, "\"crv\" not defined for EC key"
  end
  if not tbl["x"] or not tbl["y"] then
    return nil, "at least \"x\" and \"y\" parameter is required"
  end
  local curve_nid = ec_curves[curve]
  if not curve_nid then
    -- "crv" comes from JSON and may be a table/boolean; tostring keeps the
    -- concatenation from raising
    return nil, "curve \"" .. tostring(curve) .. "\" is not supported by this library"
  elseif curve_nid == 0 then
    -- OBJ_ln2nid returned 0 for this curve when ec_curves was built
    return nil, "curve \"" .. curve .. "\" is not supported by linked OpenSSL"
  end

  local params = {}
  local err
  for _, k in ipairs(ec_jwk_params) do
    local raw = tbl[k]
    if raw then
      -- decode_base64url expects a string; reject other JSON types so the
      -- caller gets nil, err instead of a Lua error
      if type(raw) ~= "string" then
        return nil, "parameter \"" .. k .. "\" must be a string, got " .. type(raw)
      end

      local v = b64.decode_base64url(raw)
      if not v then
        return nil, "cannot decode parameter \"" .. k .. "\" from base64 " .. raw
      end

      params[k], err = bn_lib.from_binary(v)
      if err then
        return nil, "cannot use parameter \"" .. k .. "\": " .. err
      end
    end
  end

  -- ec_lib.set_parameters expects the private scalar under "private"
  if params["d"] then
    params["private"] = params["d"]
    params["d"] = nil
  end
  params["group"] = curve_nid

  local key = C.EC_KEY_new()
  if key == nil then
    return nil, "EC_KEY_new() failed"
  end

  local _, set_err = ec_lib.set_parameters(key, params)
  if set_err ~= nil then
    C.EC_KEY_free(key)
    return nil, set_err
  end

  return key
end

--- Build an ECX (OKP) key from a decoded JWK table.
-- When "d" is present only the private key is passed to ecx_lib; otherwise
-- the public key "x" is used.
-- @param key_type EVP_PKEY type id taken from evp_macro.ecx_curves
-- @tparam table tbl decoded JWK object
-- @return the key produced by ecx_lib.set_parameters, or nil and an error
--   string
local function load_jwk_okp(key_type, tbl)
  local field
  if tbl["d"] then
    field = "d"
  elseif tbl["x"] then
    field = "x"
  else
    return nil, "at least \"x\" or \"d\" parameter is required"
  end

  local raw = tbl[field]
  -- decode_base64url expects a string; reject other JSON types explicitly
  if type(raw) ~= "string" then
    return nil, "parameter \"" .. field .. "\" must be a string, got " .. type(raw)
  end

  -- a failed decode used to leave params empty and surface as a confusing
  -- error from ecx_lib; report it directly instead
  local decoded = b64.decode_base64url(raw)
  if not decoded then
    return nil, "cannot decode parameter \"" .. field .. "\" from base64 " .. raw
  end

  local params = {}
  if field == "d" then
    params.private = decoded
  else
    params.public = decoded
  end

  local key, err = ecx_lib.set_parameters(key_type, nil, params)
  if err ~= nil then
    return nil, err
  end
  return key
end

-- Reverse of evp_macro.ecx_curves (EVP_PKEY type -> JWK "crv" name); dump_jwk
-- uses it to recognise OKP keys and to name their curve.
local ecx_curves_reverse = {}
do
  local forward = evp_macro.ecx_curves
  for crv_name, pkey_type in pairs(forward) do
    ecx_curves_reverse[pkey_type] = crv_name
  end
end

--- Parse a JWK JSON document into a key.
-- Supported "kty" values are "RSA", "EC" (curves in ec_curves) and "OKP"
-- (curves in evp_macro.ecx_curves).
-- @tparam string txt JWK as a JSON string
-- @return EVP_PKEY* cdata for RSA/EC keys, the ecx_lib result for OKP keys,
--   or nil and an error string
function _M.load_jwk(txt)
  local tbl, err = cjson.decode(txt)
  if err then
    return nil, "error decoding JSON from JWK: " .. err
  elseif type(tbl) ~= "table" then
    return nil, "expect input to be decoded as a table, got " .. type(tbl)
  end

  local kty = tbl["kty"]
  local key, key_free, key_type

  if kty == "RSA" then
    key_type = evp_macro.EVP_PKEY_RSA
    if key_type == 0 then
      return nil, "the linked OpenSSL library doesn't support RSA key"
    end
    key, err = load_jwk_rsa(tbl)
    key_free = C.RSA_free
  elseif kty == "EC" then
    key_type = evp_macro.EVP_PKEY_EC
    if key_type == 0 then
      return nil, "the linked OpenSSL library doesn't support EC key"
    end
    key, err = load_jwk_ec(tbl)
    key_free = C.EC_KEY_free
  elseif kty == "OKP" then
    local curve = tbl["crv"]
    key_type = evp_macro.ecx_curves[curve]
    if not key_type then
      return nil, "unknown curve \"" .. tostring(curve) .. "\""
    elseif key_type == 0 then
      return nil, "the linked OpenSSL library doesn't support \"" .. curve .. "\" key"
    end
    -- the OKP loader's result is returned as-is (presumably already an
    -- EVP_PKEY), unlike RSA/EC keys which are wrapped below
    key, err = load_jwk_okp(key_type, tbl)
    if key ~= nil then
      return key
    end
  else
    -- tostring: "kty" may be any JSON type, and ".." raises on tables
    return nil, "not yet supported jwk type \"" .. tostring(kty) .. "\""
  end

  if err then
    return nil, "failed to construct " .. kty .. " key from JWK: " .. err
  end

  local ctx = C.EVP_PKEY_new()
  if ctx == nil then
    key_free(key)
    return nil, "EVP_PKEY_new() failed"
  end

  -- on success EVP_PKEY_assign takes ownership of key; on failure we still
  -- own it and must release both objects
  local code = C.EVP_PKEY_assign(ctx, key_type, key)
  if code ~= 1 then
    key_free(key)
    C.EVP_PKEY_free(ctx)
    return nil, "EVP_PKEY_assign() failed"
  end

  return ctx
end

-- Byte length of a coordinate / private scalar for each supported JWK EC
-- curve. RFC 7518 section 6.2 requires "x", "y" and "d" to be encoded at
-- exactly this length.
local ec_field_sizes = {
  ["P-256"] = 32,
  ["P-384"] = 48,
  ["P-521"] = 66,
}

-- Serialize a bn object as base64url. When `size` is given, the big-endian
-- bytes are left-padded with zeros to that length, because bn:to_binary()
-- presumably follows BN_bn2bin and omits leading zero bytes.
-- Returns the encoded string, or nil and an error string.
local function bn_to_base64url(bn, size)
  local bin, err = bn:to_binary()
  if bin == nil then
    return nil, err or "bn:to_binary() failed"
  end
  if size and #bin < size then
    bin = string.rep("\0", size - #bin) .. bin
  end
  return b64.encode_base64url(bin)
end

--- Serialize a pkey as a JWK JSON string.
-- "kid" is the base64url SHA-256 of the key's DER encoding (the private DER
-- when is_priv is set).
-- NOTE(review): this is not the RFC 7638 thumbprint.
-- @param pkey resty.openssl.pkey instance holding an RSA, EC or OKP key
-- @tparam[opt] boolean is_priv include private members when true
-- @return JSON string, or nil and an error string
function _M.dump_jwk(pkey, is_priv)
  local jwk
  local key_type = pkey.key_type

  if key_type == evp_macro.EVP_PKEY_RSA then
    local params, err = pkey:get_parameters()
    if err then
      return nil, "jwk.dump_jwk: " .. err
    end
    local param_keys = { "n", "e" }
    if is_priv then
      param_keys = rsa_jwk_params
    end
    jwk = {
      kty = "RSA",
    }
    for i, p in ipairs(param_keys) do
      local bn = params[rsa_openssl_params[i]]
      -- private members may be absent (public key, or no CRT values);
      -- omit them instead of indexing nil
      if bn ~= nil then
        local encoded, enc_err = bn_to_base64url(bn)
        if not encoded then
          return nil, "jwk.dump_jwk: " .. tostring(enc_err)
        end
        jwk[p] = encoded
      end
    end
  elseif key_type == evp_macro.EVP_PKEY_EC then
    local params, err = pkey:get_parameters()
    if err then
      return nil, "jwk.dump_jwk: " .. err
    end
    local crv = ec_curves_reverse[params.group]
    if not crv then
      -- "crv" is mandatory in an EC JWK; don't emit an invalid document
      return nil, "jwk.dump_jwk: unsupported EC curve"
    end
    local size = ec_field_sizes[crv]

    local x, x_err = bn_to_base64url(params.x, size)
    if not x then
      return nil, "jwk.dump_jwk: " .. tostring(x_err)
    end
    local y, y_err = bn_to_base64url(params.y, size)
    if not y then
      return nil, "jwk.dump_jwk: " .. tostring(y_err)
    end
    jwk = {
      kty = "EC",
      crv = crv,
      x = x,
      y = y,
    }
    if is_priv and params.private ~= nil then
      local d, d_err = bn_to_base64url(params.private, size)
      if not d then
        return nil, "jwk.dump_jwk: " .. tostring(d_err)
      end
      jwk.d = d
    end
  elseif ecx_curves_reverse[key_type] then
    local params, err = pkey:get_parameters()
    if err then
      return nil, "jwk.dump_jwk: " .. err
    end
    jwk = {
      kty = "OKP",
      crv = ecx_curves_reverse[key_type],
      x = b64.encode_base64url(params.public),
    }
    -- only expose the private key when explicitly requested (and present)
    if is_priv and params.private ~= nil then
      jwk.d = b64.encode_base64url(params.private)
    end
  else
    return nil, "jwk.dump_jwk: not implemented for this key type"
  end

  local der, der_err = pkey:tostring(is_priv and "private" or "public", "DER")
  if not der then
    return nil, "jwk.dump_jwk: " .. tostring(der_err)
  end
  local dgst, dgst_err = digest_lib.new("sha256")
  if not dgst then
    return nil, "jwk.dump_jwk: " .. tostring(dgst_err)
  end
  local d, final_err = dgst:final(der)
  if final_err then
    return nil, "jwk.dump_jwk: failed to calculate digest for key"
  end
  jwk.kid = b64.encode_base64url(d)

  return cjson.encode(jwk)
end

return _M