aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/jwt-validators.lua
blob: df9941864f8d72f9b0869a2a00b19b2f6bb83869 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
-- Public module table; every exported validator is stored on it.
local _M = {}
_M._VERSION = "0.2.3"

--[[
  This file defines "validators" to be used in validating a spec.  A "validator" is simply a function with
  a signature that matches:

    function(val, claim, jwt_json)

  This function returns either true or false.  If a validator needs to give more information on why it failed,
  then it can also raise an error (which will be used in the "reason" part of the validated jwt_obj).  If a
  validator returns nil, then it is assumed to have passed (same as returning true) and that you just forgot
  to actually return a value.

  There is a special claim name of "__jwt" that can be used to validate the entire jwt_obj.

  "val" is the value being tested.  It may be nil if the claim doesn't exist in the jwt_obj.  If the function
  is being called for the "__jwt" claim, then "val" will contain a deep clone of the full jwt object.

  "claim" is the name of the claim being tested.  It is passed in case a validator needs to do additional
  checks.  It will be the string "__jwt" if the validator is being called for the entire jwt_obj.

  "jwt_json" is a json-encoded representation of the full object that is being tested.  It will never be nil,
  and can always be decoded using cjson.decode(jwt_json).
]]--


--[[
    Registers a validator factory under two names on _M:
      * "opt_<name>" is fx itself (the *optional* variant, supplied by the caller);
      * "<name>" is a wrapper that calls fx with the same arguments and chains the
        result after _M.required(), so a nil value fails instead of passing.
]]--
local function define_validator(name, fx)
  local function required_variant(...)
    return _M.chain(_M.required(), fx(...))
  end

  _M["opt_" .. name] = fx
  _M[name] = required_variant
end

-- Format strings for the errors this module raises.  The "*_validator" entries
-- are raised while a validator is being built (bad arguments); the "*_claim"
-- entries are raised while a built validator runs against a claim.
local messages = {
  -- Validator-construction errors.
  nil_validator              = "Cannot create validator for nil %s.",
  wrong_type_validator       = "Cannot create validator for non-%s %s.",
  empty_table_validator      = "Cannot create validator for empty table %s.",
  wrong_table_type_validator = "Cannot create validator for non-%s table %s.",

  -- Claim-validation errors.
  required_claim   = "'%s' claim is required.",
  wrong_type_claim = "'%s' is malformed.  Expected to be a %s.",
  missing_claim    = "Missing one of claims - [ %s ].",
}

-- Local function to make sure that a value is non-nil.
-- Returns v unchanged when it is not nil (including when it is `false`);
-- otherwise raises string.format(e, ...) with level 0, so no position
-- information is prepended to the message.
-- An explicit `if` is used instead of `v ~= nil and v or error(...)`.  That
-- idiom evaluates to the error branch when v is `false`, which made a present
-- `false` claim look missing and made equals(false) impossible to build.
local function ensure_not_nil(v, e, ...)
  if v == nil then
    error(string.format(e, ...), 0)
  end
  return v
end

-- Local function to make sure that a value is the given type.
-- Returns v unchanged when type(v) == t; otherwise raises
-- string.format(e, ...) with level 0 (no position prefix).
-- An explicit `if` is used because `type(v) == t and v or error(...)` raises
-- for a legitimate `false` when t == "boolean".  For example, equals(true)
-- tested against a `false` claim reported it as malformed instead of
-- returning false.
local function ensure_is_type(v, t, e, ...)
  if type(v) ~= t then
    error(string.format(e, ...), 0)
  end
  return v
end

-- Local function to make sure that a value is a non-empty table.
-- Raises string.format(e, ...) if v is not a table or has no entries.
-- Note: the value returned is the table's first key (from next), not v.
local function ensure_is_table(v, e, ...)
  ensure_is_type(v, "table", e, ...)
  local first_key = next(v)
  return ensure_not_nil(first_key, e, ...)
end

-- Local function to make sure a table is non-empty and that every element of
-- its sequence part (as walked by ipairs) has type t.  A nil v is accepted
-- as-is.  Returns v; raises string.format(e, ...) on any violation.
local function ensure_is_table_type(v, t, e, ...)
  if v == nil then
    return v
  end

  ensure_is_table(v, e, ...)
  for _, element in ipairs(v) do
    ensure_is_type(element, t, e, ...)
  end
  return v
end

-- Local function to ensure that a value is a non-negative number (>= 0).
-- A nil v is accepted and nothing is returned; otherwise v is returned.
-- Raises string.format(e, ...) if v is not a number or is not >= 0.
-- `not (v >= 0)` is used so that NaN is rejected too.
local function ensure_is_non_negative(v, e, ...)
  if v == nil then
    return
  end

  ensure_is_type(v, "number", e, ...)
  if not (v >= 0) then
    error(string.format(e, ...), 0)
  end
  return v
end

-- Comparison helper: true when the tested value is == to the expected one.
local function equality_function(actual, expected)
  local same = (actual == expected)
  return same
end

-- Comparison helper: true when the Lua pattern matches somewhere in subject.
local function string_match_function(subject, pattern)
  local first_capture = string.match(subject, pattern)
  return first_capture ~= nil
end

--[[
    Comparison helper: true when any value stored in haystack (array or hash
    part, walked with pairs) is == to needle.
    Adopted from auth0/nginx-jwt table_contains by @twistedstream
]]--
local function table_contains_function(haystack, needle)
  local found = false
  for _, candidate in pairs(haystack) do
    if candidate == needle then
      found = true
      break
    end
  end
  return found
end


-- Numeric comparison helpers.  Each takes (tested value, reference value)
-- and returns the boolean result of the named relational operator.
local function greater_than_function(val, check) return val > check end
local function greater_than_or_equal_function(val, check) return val >= check end
local function less_than_function(val, check) return val < check end
local function less_than_or_equal_function(val, check) return val <= check end


--[[
    Returns a validator that chains the given functions together, one after
    another - as long as they keep passing their checks.  Each chained function
    is called as fx(val, claim, jwt_json).  The chain stops and returns false at
    the first one that returns exactly false; any other result (including nil)
    counts as a pass.  Errors raised by a chained function propagate.

    Every argument must be a function; an error is raised at construction time
    otherwise.  The arguments are counted with select("#", ...) rather than
    ipairs({...}).  ipairs stops at the first nil, so chain(a, nil, b) would
    silently skip both the type check and the execution of b.
]]--
function _M.chain(...)
  local count = select("#", ...)
  local chain_functions = { ... }
  for i = 1, count do
    ensure_is_type(chain_functions[i], "function", messages.wrong_type_validator, "function", "chain_function")
  end

  return function(val, claim, jwt_json)
    for i = 1, count do
      if chain_functions[i](val, claim, jwt_json) == false then
        return false
      end
    end
    return true
  end
end

--[[
    Returns a validator that raises "'<claim>' claim is required." when the
    tested value is nil, and returns true otherwise.  If chain_function is given,
    the result is _M.chain(_M.required(), chain_function): the value must be
    present *and* pass chain_function.  Used by the required (non-"opt_")
    validator variants.
]]--
function _M.required(chain_function)
  if chain_function == nil then
    return function(val, claim, jwt_json)
      ensure_not_nil(val, messages.required_claim, claim)
      return true
    end
  end

  return _M.chain(_M.required(), chain_function)
end

--[[
    Returns a validator meant for the whole jwt object (the "__jwt" claim).  It
    passes when at least one of claim_keys is present (non-nil) in val.payload
    and otherwise raises "Missing one of claims - [ ... ].".  It also raises if
    val or val.payload is not a table.  claim_keys must be a non-empty table of
    strings; this is checked when the validator is built.
]]--
function _M.require_one_of(claim_keys)
  local label = "claim_keys"
  ensure_not_nil(claim_keys, messages.nil_validator, label)
  ensure_is_type(claim_keys, "table", messages.wrong_type_validator, "table", label)
  ensure_is_table(claim_keys, messages.empty_table_validator, label)
  ensure_is_table_type(claim_keys, "string", messages.wrong_table_type_validator, "string", label)

  return function(val, claim, jwt_json)
    ensure_is_type(val, "table", messages.wrong_type_claim, claim, "table")
    local payload = ensure_is_type(val.payload, "table", messages.wrong_type_claim, claim .. ".payload", "table")

    for _, key in ipairs(claim_keys) do
      if payload[key] ~= nil then
        return true
      end
    end

    error(string.format(messages.missing_claim, table.concat(claim_keys, ", ")), 0)
  end
end

--[[
    Returns a validator that tests a value by calling
        check_function(val, check_val)
    and returning its result.  check_val and check_function cannot be nil, and
    check_function must be a function.  The optional name is used in error
    messages and defaults to "check_val".  The optional check_type is the type
    the tested value must have and defaults to type(check_val).  A tested value
    of any other type raises a "malformed" error.

    check_function is *never* called with a nil value: if the tested value is
    nil, this (opt_) check succeeds.  Use the required version to fail on nil.
    Errors raised by check_function are not caught; they propagate unchanged.
]]--
define_validator("check", function(check_val, check_function, name, check_type)
  name = name or "check_val"
  ensure_not_nil(check_val, messages.nil_validator, name)

  ensure_not_nil(check_function, messages.nil_validator, "check_function")
  ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function")

  check_type = check_type or type(check_val)
  return function(val, claim, jwt_json)
    -- Absent values pass here; presence is enforced by the required variant.
    if val == nil then return true end

    ensure_is_type(val, check_type, messages.wrong_type_claim, claim, check_type)
    return check_function(val, check_val)
  end
end)


--[[
    Returns a validator that passes when the tested value is == to check_val.
    The tested value must have the same type as check_val.  A nil tested value
    passes in the opt_ variant.  check_val cannot be nil.
]]--
define_validator("equals", function(check_val)
  local validator = _M.opt_check(check_val, equality_function, "check_val")
  return validator
end)


--[[
    Returns a validator that passes when the tested value (which must be a
    string) matches the Lua pattern `pattern`.  pattern must itself be a string.
    A nil tested value passes in the opt_ variant.
]]--
define_validator("matches", function(pattern)
  ensure_is_type(pattern, "string", messages.wrong_type_validator, "string", "pattern")
  local validator = _M.opt_check(pattern, string_match_function, "pattern", "string")
  return validator
end)


--[[
    Returns a validator that calls check_function(val, candidate) for each
    candidate in check_values (in ipairs order).  It passes as soon as one call
    returns a truthy value and returns false if none do.

      check_values   - non-empty table whose elements all have type table_type.
      check_function - function used to compare the tested value to a candidate.
      name           - optional label for error messages; defaults to "check_values".
      check_type     - optional type the tested value must have; defaults to table_type.
      table_type     - optional element type of check_values; defaults to
                       type(check_values[1]).

    A nil tested value passes in the opt_ variant.
]]--
define_validator("any_of", function(check_values, check_function, name, check_type, table_type)
  local label = name or "check_values"

  -- The candidate list must be present, a table, non-empty and homogeneous.
  ensure_not_nil(check_values, messages.nil_validator, label)
  ensure_is_type(check_values, "table", messages.wrong_type_validator, "table", label)
  ensure_is_table(check_values, messages.empty_table_validator, label)

  local element_type = table_type or type(check_values[1])
  ensure_is_table_type(check_values, element_type, messages.wrong_table_type_validator, element_type, label)

  -- The comparison function must be present and callable.
  ensure_not_nil(check_function, messages.nil_validator, "check_function")
  ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function")

  local function matches_any(tested, candidates)
    for _, candidate in ipairs(candidates) do
      if check_function(tested, candidate) then
        return true
      end
    end
    return false
  end

  return _M.opt_check(check_values, matches_any, label, check_type or element_type)
end)


--[[
    Returns a validator that passes when the tested value is == to at least one
    element of check_values.  The tested value must have the elements' type.
]]--
define_validator("equals_any_of", function(check_values)
  local validator = _M.opt_any_of(check_values, equality_function, "check_values")
  return validator
end)


--[[
    Returns a validator that passes when the tested string matches at least one
    of the Lua patterns in `patterns` (a non-empty table of strings).
]]--
define_validator("matches_any_of", function(patterns)
  local validator = _M.opt_any_of(patterns, string_match_function, "patterns", "string", "string")
  return validator
end)

--[[
    Returns a validator whose tested value must be a table (e.g. a list-valued
    claim).  It passes when that table contains, as one of its values, at least
    one of the strings in check_values (a non-empty table of strings).  The
    optional name is used in error messages and defaults to "check_values".
]]--
define_validator("contains_any_of", function(check_values, name)
  local validator = _M.opt_any_of(check_values, table_contains_function, name, "table", "string")
  return validator
end)

--[[
    Numeric comparison validators: greater_than, greater_than_or_equal,
    less_than and less_than_or_equal (plus their opt_ variants).  Each one
    compares the tested value (which must be a number) against check_val with
    the named operator.  check_val cannot be nil and must be a number.
]]--
for _, spec in ipairs({
  { "greater_than",          greater_than_function },
  { "greater_than_or_equal", greater_than_or_equal_function },
  { "less_than",             less_than_function },
  { "less_than_or_equal",    less_than_or_equal_function },
}) do
  local validator_name, compare = spec[1], spec[2]
  define_validator(validator_name, function(check_val)
    ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val")
    return _M.opt_check(check_val, compare, "check_val", "number")
  end)
end


--[[
    Sets the leeway, in seconds, applied by is_not_before, is_not_expired and
    is_at.  It starts at 0.  Raises if leeway is not a number >= 0.
]]--
local system_leeway = 0
function _M.set_system_leeway(leeway)
  local err = "leeway must be a non-negative number"
  ensure_is_type(leeway, "number", err)
  ensure_is_non_negative(leeway, err)
  system_leeway = leeway
end


--[[
    Replaces the clock used by is_not_before, is_not_expired and is_at (it
    starts as ngx.now).  The clock is called once up front and must return a
    number >= 0; otherwise an error is raised and the clock is not changed.
]]--
local system_clock = ngx.now
function _M.set_system_clock(clock)
  ensure_is_type(clock, "function", "clock must be a function")

  -- Probe the clock once so a misbehaving one is rejected immediately.
  local sample = clock()
  local err = "clock function must return a non-negative number"
  ensure_is_type(sample, "number", err)
  ensure_is_non_negative(sample, err)

  system_clock = clock
end

-- Local chainable validator for date claims: raises a "malformed" error unless
-- val is nil or a number >= 0, and otherwise returns true.
local function validate_is_date(val, claim, _jwt_json)
  ensure_is_non_negative(val, messages.wrong_type_claim, claim, "positive numeric value")
  return true
end

-- Local helper that wraps a date validator.  The returned validator calls
-- date_check_function(val, claim, jwt_json).  If that returns exactly false,
-- it raises "'<claim>' claim <error_msg> <ngx.http_time(val)>"; otherwise it
-- returns true.
-- Raises at construction time if date_check_function is not a function or
-- error_msg is not a string.
local function format_date_on_error(date_check_function, error_msg)
  ensure_is_type(date_check_function, "function", messages.wrong_type_validator, "function", "date_check_function")
  -- Name the parameter, not its value.  error_msg may be nil or a non-string
  -- here, which would garble the message or make string.format itself throw
  -- (Lua 5.1 / LuaJIT reject nil for %s).
  ensure_is_type(error_msg, "string", messages.wrong_type_validator, "string", "error_msg")
  return function(val, claim, jwt_json)
    local ret = date_check_function(val, claim, jwt_json)
    if ret == false then
      error(string.format("'%s' claim %s %s", claim, error_msg, ngx.http_time(val)), 0)
    end
    return true
  end
end

--[[
    Returns a validator for "not before" style claims.  It passes when
      val <= (system_clock() + system_leeway)
    and otherwise raises "'<claim>' claim not valid until <http time>".  The
    value must be a number >= 0; nil passes in the opt_ variant.
]]--
define_validator("is_not_before", function()
  local function reached_start(val)
    if not val then
      return val
    end
    return less_than_or_equal_function(val, system_clock() + system_leeway)
  end

  local checked = _M.chain(validate_is_date, reached_start)
  return format_date_on_error(checked, "not valid until")
end)


--[[
    Returns a validator for expiry claims.  It passes while
      val > (system_clock() - system_leeway)
    and otherwise raises "'<claim>' claim expired at <http time>".  The value
    must be a number >= 0; nil passes in the opt_ variant.
]]--
define_validator("is_not_expired", function()
  local function still_valid(val)
    if not val then
      return val
    end
    return greater_than_function(val, system_clock() - system_leeway)
  end

  local checked = _M.chain(validate_is_date, still_valid)
  return format_date_on_error(checked, "expired at")
end)

--[[
    Returns a validator that checks if the current time is the same as the tested value
    within the system's leeway.  This means that:
      val >= (system_clock() - system_leeway) and val <= (system_clock() + system_leeway).
    On failure it raises "'<claim>' claim is only valid at <http time>".  The value
    must be a number >= 0; nil passes in the opt_ variant.
]]--
define_validator("is_at", function()
  return format_date_on_error(
    _M.chain(validate_is_date,
             function(val)
                -- Read the clock once per validation so both bounds use the same instant.
                local now = system_clock()
                return val and
                   greater_than_or_equal_function(val, now - system_leeway) and
                   less_than_or_equal_function(val, now + system_leeway)
             end),
    "is only valid at"
  )
end)


-- Export the module table holding all validators and configuration setters.
return _M