aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/jwt-validators.lua
diff options
context:
space:
mode:
Diffstat (limited to 'server/resty/jwt-validators.lua')
-rw-r--r--server/resty/jwt-validators.lua412
1 files changed, 412 insertions, 0 deletions
diff --git a/server/resty/jwt-validators.lua b/server/resty/jwt-validators.lua
new file mode 100644
index 0000000..df99418
--- /dev/null
+++ b/server/resty/jwt-validators.lua
@@ -0,0 +1,412 @@
+local _M = { _VERSION = "0.2.3" }
+
+--[[
+ This file defines "validators" to be used in validating a spec. A "validator" is simply a function with
+ a signature that matches:
+
+ function(val, claim, jwt_json)
+
+ This function returns either true or false. If a validator needs to give more information on why it failed,
+ then it can also raise an error (which will be used in the "reason" part of the validated jwt_obj). If a
+ validator returns nil, then it is assumed to have passed (same as returning true) and that you just forgot
+ to actually return a value.
+
+ There is a special claim name of "__jwt" that can be used to validate the entire jwt_obj.
+
+ "val" is the value being tested. It may be nil if the claim doesn't exist in the jwt_obj. If the function
+ is being called for the "__jwt" claim, then "val" will contain a deep clone of the full jwt object.
+
+ "claim" is the claim that is being tested. It is passed in just in case a validator needs to do additional
+ checks. It will be the string "__jwt" if the validator is being called for the entire jwt_object.
+
+ "jwt_json" is a json-encoded representation of the full object that is being tested. It will never be nil,
+ and can always be decoded using cjson.decode(jwt_json).
+]]--
+
+
+--[[
+ A function which will define a validator. It creates both "opt_" and required (non-"opt_")
+ versions. The function that is passed in is the *optional* version.
+]]--
+local function define_validator(name, fx)
+ _M["opt_" .. name] = fx
+ _M[name] = function(...) return _M.chain(_M.required(), fx(...)) end
+end
+
+-- Validation messages
+local messages = {
+ nil_validator = "Cannot create validator for nil %s.",
+ wrong_type_validator = "Cannot create validator for non-%s %s.",
+ empty_table_validator = "Cannot create validator for empty table %s.",
+ wrong_table_type_validator = "Cannot create validator for non-%s table %s.",
+ required_claim = "'%s' claim is required.",
+ wrong_type_claim = "'%s' is malformed. Expected to be a %s.",
+ missing_claim = "Missing one of claims - [ %s ]."
+}
+
+-- Local function to make sure that a value is non-nil or raises an error
+local function ensure_not_nil(v, e, ...)
+ return v ~= nil and v or error(string.format(e, ...), 0)
+end
+
+-- Local function to make sure that a value is the given type
+local function ensure_is_type(v, t, e, ...)
+ return type(v) == t and v or error(string.format(e, ...), 0)
+end
+
+-- Local function to make sure that a value is a (non-empty) table
+local function ensure_is_table(v, e, ...)
+ ensure_is_type(v, "table", e, ...)
+ return ensure_not_nil(next(v), e, ...)
+end
+
+-- Local function to make sure all entries in the table are the given type
+local function ensure_is_table_type(v, t, e, ...)
+ if v ~= nil then
+ ensure_is_table(v, e, ...)
+ for _,val in ipairs(v) do
+ ensure_is_type(val, t, e, ...)
+ end
+ end
+ return v
+end
+
+-- Local function to ensure that a number is non-negative (positive or 0)
+local function ensure_is_non_negative(v, e, ...)
+ if v ~= nil then
+ ensure_is_type(v, "number", e, ...)
+ if v >= 0 then
+ return v
+ else
+ error(string.format(e, ...), 0)
+ end
+ end
+end
+
+-- A local function which returns simple equality
+local function equality_function(val, check)
+ return val == check
+end
+
+-- A local function which returns string match
+local function string_match_function(val, pattern)
+ return string.match(val, pattern) ~= nil
+end
+
+--[[
+ A local function which returns truth on existence of check in vals.
+ Adopted from auth0/nginx-jwt table_contains by @twistedstream
+]]--
+local function table_contains_function(vals, check)
+ for _, val in pairs(vals) do
+ if val == check then return true end
+ end
+ return false
+end
+
+
+-- A local function which returns numeric greater than comparison
+local function greater_than_function(val, check)
+ return val > check
+end
+
+-- A local function which returns numeric greater than or equal comparison
+local function greater_than_or_equal_function(val, check)
+ return val >= check
+end
+
+-- A local function which returns numeric less than comparison
+local function less_than_function(val, check)
+ return val < check
+end
+
+-- A local function which returns numeric less than or equal comparison
+local function less_than_or_equal_function(val, check)
+ return val <= check
+end
+
+
+--[[
+ Returns a validator that chains the given functions together, one after
+ another - as long as they keep passing their checks.
+]]--
+function _M.chain(...)
+ local chain_functions = {...}
+ for _, fx in ipairs(chain_functions) do
+ ensure_is_type(fx, "function", messages.wrong_type_validator, "function", "chain_function")
+ end
+
+ return function(val, claim, jwt_json)
+ for _, fx in ipairs(chain_functions) do
+ if fx(val, claim, jwt_json) == false then
+ return false
+ end
+ end
+ return true
+ end
+end
+
+--[[
+ Returns a validator that returns false if a value doesn't exist. If
+ the value exists and a chain_function is specified, then the value of
+ chain_function(val, claim, jwt_json)
+ will be returned, otherwise, true will be returned. This allows for
+ specifying that a value is both required *and* it must match some
+ additional check. This function will be used in the "required_*" shortcut
+ functions for simplification.
+]]--
+function _M.required(chain_function)
+ if chain_function ~= nil then
+ return _M.chain(_M.required(), chain_function)
+ end
+
+ return function(val, claim, jwt_json)
+ ensure_not_nil(val, messages.required_claim, claim)
+ return true
+ end
+end
+
+--[[
+ Returns a validator which errors with a message if *NONE* of the given claim
+ keys exist. It is expected that this function is used against a full jwt object.
+ The claim_keys must be a non-empty table of strings.
+]]--
+function _M.require_one_of(claim_keys)
+ ensure_not_nil(claim_keys, messages.nil_validator, "claim_keys")
+ ensure_is_type(claim_keys, "table", messages.wrong_type_validator, "table", "claim_keys")
+ ensure_is_table(claim_keys, messages.empty_table_validator, "claim_keys")
+ ensure_is_table_type(claim_keys, "string", messages.wrong_table_type_validator, "string", "claim_keys")
+
+ return function(val, claim, jwt_json)
+ ensure_is_type(val, "table", messages.wrong_type_claim, claim, "table")
+ ensure_is_type(val.payload, "table", messages.wrong_type_claim, claim .. ".payload", "table")
+
+ for i, v in ipairs(claim_keys) do
+ if val.payload[v] ~= nil then return true end
+ end
+
+ error(string.format(messages.missing_claim, table.concat(claim_keys, ", ")), 0)
+ end
+end
+
+--[[
+ Returns a validator that checks if the result of calling the given function for
+ the tested value and the check value returns true. The value of check_val and
+ check_function cannot be nil. The optional name is used for error messages and
+ defaults to "check_value". The optional check_type is used to make sure that
+ the check type matches and defaults to type(check_val). The first parameter
+ passed to check_function will *never* be nil (check succeeds if value is nil).
+ Use the required version to fail on nil. If the check_function raises an
+ error, that will be appended to the error message.
+]]--
+define_validator("check", function(check_val, check_function, name, check_type)
+ name = name or "check_val"
+ ensure_not_nil(check_val, messages.nil_validator, name)
+
+ ensure_not_nil(check_function, messages.nil_validator, "check_function")
+ ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function")
+
+ check_type = check_type or type(check_val)
+ return function(val, claim, jwt_json)
+ if val == nil then return true end
+
+ ensure_is_type(val, check_type, messages.wrong_type_claim, claim, check_type)
+ return check_function(val, check_val)
+ end
+end)
+
+
+--[[
+ Returns a validator that checks if a value exactly equals the given check_value.
+ If the value is nil, then this check succeeds. The value of check_val cannot be
+ nil.
+]]--
+define_validator("equals", function(check_val)
+ return _M.opt_check(check_val, equality_function, "check_val")
+end)
+
+
+--[[
+ Returns a validator that checks if a value matches the given pattern. The value
+ of pattern must be a string.
+]]--
+define_validator("matches", function (pattern)
+ ensure_is_type(pattern, "string", messages.wrong_type_validator, "string", "pattern")
+ return _M.opt_check(pattern, string_match_function, "pattern", "string")
+end)
+
+
+--[[
+ Returns a validator which calls the given function for each of the given values
+ and the tested value. If any of these calls return true, then this function
+ returns true. The value of check_values must be a non-empty table with all the
+ same types, and the value of check_function must not be nil. The optional name
+ is used for error messages and defaults to "check_values". The optional
+ check_type is used to make sure that the check type matches and defaults to
+ type(check_values[1]) - the table type.
+]]--
+define_validator("any_of", function(check_values, check_function, name, check_type, table_type)
+ name = name or "check_values"
+ ensure_not_nil(check_values, messages.nil_validator, name)
+ ensure_is_type(check_values, "table", messages.wrong_type_validator, "table", name)
+ ensure_is_table(check_values, messages.empty_table_validator, name)
+
+ table_type = table_type or type(check_values[1])
+ ensure_is_table_type(check_values, table_type, messages.wrong_table_type_validator, table_type, name)
+
+ ensure_not_nil(check_function, messages.nil_validator, "check_function")
+ ensure_is_type(check_function, "function", messages.wrong_type_validator, "function", "check_function")
+
+ check_type = check_type or table_type
+ return _M.opt_check(check_values, function(v1, v2)
+ for i, v in ipairs(v2) do
+ if check_function(v1, v) then return true end
+ end
+ return false
+ end, name, check_type)
+end)
+
+
+--[[
+ Returns a validator that checks if a value exactly equals any of the given values.
+]]--
+define_validator("equals_any_of", function(check_values)
+ return _M.opt_any_of(check_values, equality_function, "check_values")
+end)
+
+
+--[[
+ Returns a validator that checks if a value matches any of the given patterns.
+]]--
+define_validator("matches_any_of", function(patterns)
+ return _M.opt_any_of(patterns, string_match_function, "patterns", "string", "string")
+end)
+
+--[[
+ Returns a validator that checks if a value of expected type string exists in any of the given values.
+ The value of check_values must be a non-empty table with all the same types.
+ The optional name is used for error messages and defaults to "check_values".
+]]--
+define_validator("contains_any_of", function(check_values, name)
+ return _M.opt_any_of(check_values, table_contains_function, name, "table", "string")
+end)
+
+--[[
+ Returns a validator that checks how a value compares (numerically) to a given
+ check_value. The value of check_val cannot be nil and must be a number.
+]]--
+define_validator("greater_than", function(check_val)
+ ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val")
+ return _M.opt_check(check_val, greater_than_function, "check_val", "number")
+end)
+define_validator("greater_than_or_equal", function(check_val)
+ ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val")
+ return _M.opt_check(check_val, greater_than_or_equal_function, "check_val", "number")
+end)
+define_validator("less_than", function(check_val)
+ ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val")
+ return _M.opt_check(check_val, less_than_function, "check_val", "number")
+end)
+define_validator("less_than_or_equal", function(check_val)
+ ensure_is_type(check_val, "number", messages.wrong_type_validator, "number", "check_val")
+ return _M.opt_check(check_val, less_than_or_equal_function, "check_val", "number")
+end)
+
+
+--[[
+ A function to set the leeway (in seconds) used for is_not_before and is_not_expired. The
+ default is to use 0 seconds
+]]--
+local system_leeway = 0
+function _M.set_system_leeway(leeway)
+ ensure_is_type(leeway, "number", "leeway must be a non-negative number")
+ ensure_is_non_negative(leeway, "leeway must be a non-negative number")
+ system_leeway = leeway
+end
+
+
+--[[
+ A function to set the system clock used for is_not_before and is_not_expired. The
+ default is to use ngx.now
+]]--
+local system_clock = ngx.now
+function _M.set_system_clock(clock)
+ ensure_is_type(clock, "function", "clock must be a function")
+ -- Check that clock returns the correct value
+ local t = clock()
+ ensure_is_type(t, "number", "clock function must return a non-negative number")
+ ensure_is_non_negative(t, "clock function must return a non-negative number")
+ system_clock = clock
+end
+
+-- Local helper function for date validation
+local function validate_is_date(val, claim, jwt_json)
+ ensure_is_non_negative(val, messages.wrong_type_claim, claim, "positive numeric value")
+ return true
+end
+
+-- Local helper for date formatting
+local function format_date_on_error(date_check_function, error_msg)
+ ensure_is_type(date_check_function, "function", messages.wrong_type_validator, "function", "date_check_function")
+ ensure_is_type(error_msg, "string", messages.wrong_type_validator, "string", error_msg)
+ return function(val, claim, jwt_json)
+ local ret = date_check_function(val, claim, jwt_json)
+ if ret == false then
+ error(string.format("'%s' claim %s %s", claim, error_msg, ngx.http_time(val)), 0)
+ end
+ return true
+ end
+end
+
+--[[
+ Returns a validator that checks if the current time is not before the tested value
+ within the system's leeway. This means that:
+ val <= (system_clock() + system_leeway).
+]]--
+define_validator("is_not_before", function()
+ return format_date_on_error(
+ _M.chain(validate_is_date,
+ function(val)
+ return val and less_than_or_equal_function(val, (system_clock() + system_leeway))
+ end),
+ "not valid until"
+ )
+end)
+
+
+--[[
+ Returns a validator that checks if the current time is not equal to or after the
+ tested value within the system's leeway. This means that:
+ val > (system_clock() - system_leeway).
+]]--
+define_validator("is_not_expired", function()
+ return format_date_on_error(
+ _M.chain(validate_is_date,
+ function(val)
+ return val and greater_than_function(val, (system_clock() - system_leeway))
+ end),
+ "expired at"
+ )
+end)
+
+--[[
+ Returns a validator that checks if the current time is the same as the tested value
+ within the system's leeway. This means that:
+ val >= (system_clock() - system_leeway) and val <= (system_clock() + system_leeway).
+]]--
+define_validator("is_at", function()
+ local now = system_clock()
+ return format_date_on_error(
+ _M.chain(validate_is_date,
+ function(val)
+ local now = system_clock()
+ return val and
+ greater_than_or_equal_function(val, now - system_leeway) and
+ less_than_or_equal_function(val, now + system_leeway)
+ end),
+ "is only valid at"
+ )
+end)
+
+
+return _M