aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/session.lua
diff options
context:
space:
mode:
Diffstat (limited to 'server/resty/session.lua')
-rw-r--r--server/resty/session.lua771
1 files changed, 771 insertions, 0 deletions
diff --git a/server/resty/session.lua b/server/resty/session.lua
new file mode 100644
index 0000000..12dfe53
--- /dev/null
+++ b/server/resty/session.lua
@@ -0,0 +1,771 @@
+local require = require
+
+local random = require "resty.random"
+
+local ngx = ngx
+local var = ngx.var
+local time = ngx.time
+local header = ngx.header
+local http_time = ngx.http_time
+local set_header = ngx.req.set_header
+local clear_header = ngx.req.clear_header
+local concat = table.concat
+local ceil = math.ceil
+local max = math.max
+local find = string.find
+local gsub = string.gsub
+local byte = string.byte
+local sub = string.sub
+local type = type
+local pcall = pcall
+local tonumber = tonumber
+local setmetatable = setmetatable
+local getmetatable = getmetatable
+local bytes = random.bytes
+
+local UNDERSCORE = byte("_")
+local EXPIRE_FLAGS = "; Expires=Thu, 01 Jan 1970 00:00:01 GMT; Max-Age=0"
+
+local COOKIE_PARTS = {
+ DEFAULT = {
+ n = 3,
+ "id",
+ "expires", -- may also contain: `expires:usebefore`
+ "hash"
+ },
+ cookie = {
+ n = 4,
+ "id",
+ "expires", -- may also contain: `expires:usebefore`
+ "data",
+ "hash",
+ },
+}
+
+local function enabled(value)
+ if value == nil then
+ return nil
+ end
+
+ return value == true
+ or value == "1"
+ or value == "true"
+ or value == "on"
+end
+
+local function ifnil(value, default)
+ if value == nil then
+ return default
+ end
+
+ return enabled(value)
+end
+
+local function prequire(prefix, package, default)
+ if type(package) == "table" then
+ return package, package.name
+ end
+
+ local ok, module = pcall(require, prefix .. package)
+ if not ok then
+ return require(prefix .. default), default
+ end
+
+ return module, package
+end
+
+local function is_session_cookie(cookie, name, name_len)
+ if not cookie or cookie == "" then
+ return false, nil
+ end
+
+ cookie = gsub(cookie, "^%s+", "")
+ if cookie == "" then
+ return false, nil
+ end
+
+ cookie = gsub(cookie, "%s+$", "")
+ if cookie == "" then
+ return false, nil
+ end
+
+ local eq_pos = find(cookie, "=", 1, true)
+ if not eq_pos then
+ return false, cookie
+ end
+
+ local cookie_name = sub(cookie, 1, eq_pos - 1)
+ if cookie_name == "" then
+ return false, cookie
+ end
+
+ cookie_name = gsub(cookie_name, "%s+$", "")
+ if cookie_name == "" then
+ return false, cookie
+ end
+
+ if cookie_name ~= name then
+ if find(cookie_name, name, 1, true) ~= 1 then
+ return false, cookie
+ end
+
+ if byte(cookie_name, name_len + 1) ~= UNDERSCORE then
+ return false, cookie
+ end
+
+ if not tonumber(sub(cookie_name, name_len + 2), 10) then
+ return false, cookie
+ end
+ end
+
+ return true, cookie
+end
+
+local function set_cookie(session, value, expires)
+ if ngx.headers_sent then
+ return nil, "attempt to set session cookie after sending out response headers"
+ end
+
+ value = value or ""
+
+ local cookie = session.cookie
+ local output = {}
+
+ local i = 3
+
+ -- build cookie parameters, elements 1+2 will be set later
+ if expires then
+ -- we're expiring/deleting the data, so set an expiry in the past
+ output[i] = EXPIRE_FLAGS
+ elseif cookie.persistent then
+ -- persistent cookies have an expiry
+ output[i] = "; Expires=" .. http_time(session.expires) .. "; Max-Age=" .. cookie.lifetime
+ else
+ -- just to reserve index 3 for expiry as cookie might get smaller,
+ -- and some cookies need to be expired.
+ output[i] = ""
+ end
+
+ if cookie.domain and cookie.domain ~= "localhost" and cookie.domain ~= "" then
+ i = i + 1
+ output[i] = "; Domain=" .. cookie.domain
+ end
+
+ i = i + 1
+ output[i] = "; Path=" .. (cookie.path or "/")
+
+ if cookie.samesite == "Lax"
+ or cookie.samesite == "Strict"
+ or cookie.samesite == "None"
+ then
+ i = i + 1
+ output[i] = "; SameSite=" .. cookie.samesite
+ end
+
+ if cookie.secure then
+ i = i + 1
+ output[i] = "; Secure"
+ end
+
+ if cookie.httponly then
+ i = i + 1
+ output[i] = "; HttpOnly"
+ end
+
+ -- How many chunks do we need?
+ local cookie_parts
+ local cookie_chunks
+ if expires then
+ -- expiring cookie, so deleting data. Do not measure data, but use
+ -- existing chunk count to make sure we clear all of them
+ cookie_parts = cookie.chunks or 1
+ else
+ -- calculate required chunks from data
+ cookie_chunks = max(ceil(#value / cookie.maxsize), 1)
+ cookie_parts = max(cookie_chunks, cookie.chunks or 1)
+ end
+
+ local cookie_header = header["Set-Cookie"]
+ for j = 1, cookie_parts do
+ -- create numbered chunk names if required
+ local chunk_name = { session.name }
+ if j > 1 then
+ chunk_name[2] = "_"
+ chunk_name[3] = j
+ chunk_name[4] = "="
+ else
+ chunk_name[2] = "="
+ end
+ chunk_name = concat(chunk_name)
+ output[1] = chunk_name
+
+ if expires then
+ -- expiring cookie, so deleting data; clear it
+ output[2] = ""
+ elseif j > cookie_chunks then
+ -- less chunks than before, clearing excess cookies
+ output[2] = ""
+ output[3] = EXPIRE_FLAGS
+
+ else
+ -- grab the piece for the current chunk
+ local sp = j * cookie.maxsize - (cookie.maxsize - 1)
+ if j < cookie_chunks then
+ output[2] = sub(value, sp, sp + (cookie.maxsize - 1)) .. "0"
+ else
+ output[2] = sub(value, sp)
+ end
+ end
+
+ -- build header value and add it to the header table/string
+ -- replace existing chunk-name, or append
+ local cookie_content = concat(output)
+ local header_type = type(cookie_header)
+ if header_type == "table" then
+ local found = false
+ local cookie_count = #cookie_header
+ for cookie_index = 1, cookie_count do
+ if find(cookie_header[cookie_index], chunk_name, 1, true) == 1 then
+ cookie_header[cookie_index] = cookie_content
+ found = true
+ break
+ end
+ end
+ if not found then
+ cookie_header[cookie_count + 1] = cookie_content
+ end
+ elseif header_type == "string" and find(cookie_header, chunk_name, 1, true) ~= 1 then
+ cookie_header = { cookie_header, cookie_content }
+ else
+ cookie_header = cookie_content
+ end
+ end
+
+ header["Set-Cookie"] = cookie_header
+
+ return true
+end
+
+local function get_cookie(session, i)
+ local cookie_name = { "cookie_", session.name }
+ if i then
+ cookie_name[3] = "_"
+ cookie_name[4] = i
+ else
+ i = 1
+ end
+
+ local cookie = var[concat(cookie_name)]
+ if not cookie then
+ return nil
+ end
+
+ session.cookie.chunks = i
+
+ local cookie_size = #cookie
+ if cookie_size <= session.cookie.maxsize then
+ return cookie
+ end
+
+ return concat{ sub(cookie, 1, session.cookie.maxsize), get_cookie(session, i + 1) or "" }
+end
+
+local function set_usebefore(session)
+ local usebefore = session.usebefore
+ local idletime = session.cookie.idletime
+
+ if idletime == 0 then -- usebefore is disabled
+ if usebefore then
+ session.usebefore = nil
+ return true
+ end
+
+ return false
+ end
+
+ usebefore = usebefore or 0
+
+ local new_usebefore = session.now + idletime
+ if new_usebefore - usebefore > 60 then
+ session.usebefore = new_usebefore
+ return true
+ end
+
+ return false
+end
+
+local function save(session, close)
+ session.expires = session.now + session.cookie.lifetime
+
+ set_usebefore(session)
+
+ local cookie, err = session.strategy.save(session, close)
+ if not cookie then
+ return nil, err or "unable to save session cookie"
+ end
+
+ return set_cookie(session, cookie)
+end
+
+local function touch(session, close)
+ if set_usebefore(session) then
+ -- usebefore was updated, so set cookie
+ local cookie, err = session.strategy.touch(session, close)
+ if not cookie then
+ return nil, err or "unable to touch session cookie"
+ end
+
+ return set_cookie(session, cookie)
+ end
+
+ if close then
+ local ok, err = session.strategy.close(session)
+ if not ok then
+ return nil, err
+ end
+ end
+
+ return true
+end
+
+local function regenerate(session, flush)
+ if session.strategy.destroy then
+ session.strategy.destroy(session)
+ elseif session.strategy.close then
+ session.strategy.close(session)
+ end
+
+ if flush then
+ session.data = {}
+ end
+
+ session.id = session:identifier()
+end
+
+local secret = bytes(32, true) or bytes(32)
+local defaults
+
+local function init()
+ defaults = {
+ name = var.session_name or "session",
+ identifier = var.session_identifier or "random",
+ strategy = var.session_strategy or "default",
+ storage = var.session_storage or "cookie",
+ serializer = var.session_serializer or "json",
+ compressor = var.session_compressor or "none",
+ encoder = var.session_encoder or "base64",
+ cipher = var.session_cipher or "aes",
+ hmac = var.session_hmac or "sha1",
+ cookie = {
+ path = var.session_cookie_path or "/",
+ domain = var.session_cookie_domain,
+ samesite = var.session_cookie_samesite or "Lax",
+ secure = enabled(var.session_cookie_secure),
+ httponly = enabled(var.session_cookie_httponly or true),
+ persistent = enabled(var.session_cookie_persistent or false),
+ discard = tonumber(var.session_cookie_discard, 10) or 10,
+ renew = tonumber(var.session_cookie_renew, 10) or 600,
+ lifetime = tonumber(var.session_cookie_lifetime, 10) or 3600,
+ idletime = tonumber(var.session_cookie_idletime, 10) or 0,
+ maxsize = tonumber(var.session_cookie_maxsize, 10) or 4000,
+
+ }, check = {
+ ssi = enabled(var.session_check_ssi or false),
+ ua = enabled(var.session_check_ua or true),
+ scheme = enabled(var.session_check_scheme or true),
+ addr = enabled(var.session_check_addr or false)
+ }
+ }
+ defaults.secret = var.session_secret or secret
+end
+
+local session = {
+ _VERSION = "3.10"
+}
+
+session.__index = session
+
+function session:get_cookie()
+ return get_cookie(self)
+end
+
+function session:parse_cookie(value)
+ local cookie
+ local cookie_parts = COOKIE_PARTS[self.cookie.storage] or COOKIE_PARTS.DEFAULT
+
+ local count = 1
+ local pos = 1
+
+ local p_pos = find(value, "|", 1, true)
+ while p_pos do
+ if count > (cookie_parts.n - 1) then
+ return nil, "too many session cookie parts"
+ end
+ if not cookie then
+ cookie = {}
+ end
+
+ if count == 2 then
+ local cookie_part = sub(value, pos, p_pos - 1)
+ local c_pos = find(cookie_part, ":", 2, true)
+ if c_pos then
+ cookie.expires = tonumber(sub(cookie_part, 1, c_pos - 1), 10)
+ if not cookie.expires then
+ return nil, "invalid session cookie expiry"
+ end
+
+ cookie.usebefore = tonumber(sub(cookie_part, c_pos + 1), 10)
+ if not cookie.usebefore then
+ return nil, "invalid session cookie usebefore"
+ end
+ else
+ cookie.expires = tonumber(cookie_part, 10)
+ if not cookie.expires then
+ return nil, "invalid session cookie expiry"
+ end
+ end
+ else
+ local name = cookie_parts[count]
+
+ local cookie_part = self.encoder.decode(sub(value, pos, p_pos - 1))
+ if not cookie_part then
+ return nil, "unable to decode session cookie part (" .. name .. ")"
+ end
+
+ cookie[name] = cookie_part
+ end
+
+ count = count + 1
+ pos = p_pos + 1
+
+ p_pos = find(value, "|", pos, true)
+ end
+
+ if count ~= cookie_parts.n then
+ return nil, "invalid number of session cookie parts"
+ end
+
+ local name = cookie_parts[count]
+
+ local cookie_part = self.encoder.decode(sub(value, pos))
+ if not cookie_part then
+ return nil, "unable to decode session cookie part (" .. name .. ")"
+ end
+
+ cookie[name] = cookie_part
+
+ if not cookie.id then
+ return nil, "missing session cookie id"
+ end
+
+ if not cookie.expires then
+ return nil, "missing session cookie expiry"
+ end
+
+ if cookie.expires <= self.now then
+ return nil, "session cookie has expired"
+ end
+
+ if cookie.usebefore and cookie.usebefore <= self.now then
+ return nil, "session cookie idle time has passed"
+ end
+
+ if not cookie.hash then
+ return nil, "missing session cookie signature"
+ end
+
+ return cookie
+end
+
+function session.new(opts)
+ if opts and getmetatable(opts) == session then
+ return opts
+ end
+
+ if not defaults then
+ init()
+ end
+
+ opts = type(opts) == "table" and opts or defaults
+
+ local cookie = opts.cookie or defaults.cookie
+ local name = opts.name or defaults.name
+ local sec = opts.secret or defaults.secret
+
+ local secure
+ local path
+ local domain
+ if find(name, "__Host-", 1, true) == 1 then
+ secure = true
+ path = "/"
+ else
+ if find(name, "__Secure-", 1, true) == 1 then
+ secure = true
+ else
+ secure = ifnil(cookie.secure, defaults.cookie.secure)
+ end
+
+ domain = cookie.domain or defaults.cookie.domain
+ path = cookie.path or defaults.cookie.path
+ end
+
+ local check = opts.check or defaults.check
+
+ local ide, iden = prequire("resty.session.identifiers.", opts.identifier or defaults.identifier, "random")
+ local ser, sern = prequire("resty.session.serializers.", opts.serializer or defaults.serializer, "json")
+ local com, comn = prequire("resty.session.compressors.", opts.compressor or defaults.compressor, "none")
+ local enc, encn = prequire("resty.session.encoders.", opts.encoder or defaults.encoder, "base64")
+ local cip, cipn = prequire("resty.session.ciphers.", opts.cipher or defaults.cipher, "aes")
+ local sto, ston = prequire("resty.session.storage.", opts.storage or defaults.storage, "cookie")
+ local str, strn = prequire("resty.session.strategies.", opts.strategy or defaults.strategy, "default")
+ local hma, hman = prequire("resty.session.hmac.", opts.hmac or defaults.hmac, "sha1")
+
+ local self = {
+ now = time(),
+ name = name,
+ secret = sec,
+ identifier = ide,
+ serializer = ser,
+ strategy = str,
+ encoder = enc,
+ hmac = hma,
+ cookie = {
+ storage = ston,
+ encoder = enc,
+ path = path,
+ domain = domain,
+ secure = secure,
+ samesite = cookie.samesite or defaults.cookie.samesite,
+ httponly = ifnil(cookie.httponly, defaults.cookie.httponly),
+ persistent = ifnil(cookie.persistent, defaults.cookie.persistent),
+ discard = tonumber(cookie.discard, 10) or defaults.cookie.discard,
+ renew = tonumber(cookie.renew, 10) or defaults.cookie.renew,
+ lifetime = tonumber(cookie.lifetime, 10) or defaults.cookie.lifetime,
+ idletime = tonumber(cookie.idletime, 10) or defaults.cookie.idletime,
+ maxsize = tonumber(cookie.maxsize, 10) or defaults.cookie.maxsize,
+ }, check = {
+ ssi = ifnil(check.ssi, defaults.check.ssi),
+ ua = ifnil(check.ua, defaults.check.ua),
+ scheme = ifnil(check.scheme, defaults.check.scheme),
+ addr = ifnil(check.addr, defaults.check.addr),
+ }
+ }
+ if self.cookie.idletime > 0 and self.cookie.discard > self.cookie.idletime then
+ -- if using idletime, then the discard period must be less or equal
+ self.cookie.discard = self.cookie.idletime
+ end
+
+ if iden and not self[iden] then self[iden] = opts[iden] end
+ if sern and not self[sern] then self[sern] = opts[sern] end
+ if comn and not self[comn] then self[comn] = opts[comn] end
+ if encn and not self[encn] then self[encn] = opts[encn] end
+ if cipn and not self[cipn] then self[cipn] = opts[cipn] end
+ if ston and not self[ston] then self[ston] = opts[ston] end
+ if strn and not self[strn] then self[strn] = opts[strn] end
+ if hman and not self[hman] then self[hman] = opts[hman] end
+
+ self.cipher = cip.new(self)
+ self.storage = sto.new(self)
+ self.compressor = com.new(self)
+
+ return setmetatable(self, session)
+end
+
+function session.open(opts, keep_lock)
+ local self = opts
+ if self and getmetatable(self) == session then
+ if self.opened then
+ return self, self.present
+ end
+ else
+ self = session.new(opts)
+ end
+
+ if self.cookie.secure == nil then
+ self.cookie.secure = var.scheme == "https" or var.https == "on"
+ end
+
+ self.now = time()
+ self.key = concat {
+ self.check.ssi and var.ssl_session_id or "",
+ self.check.ua and var.http_user_agent or "",
+ self.check.addr and var.remote_addr or "",
+ self.check.scheme and var.scheme or "",
+ }
+
+ self.opened = true
+
+ local err
+ local cookie = self:get_cookie()
+ if cookie then
+ cookie, err = self:parse_cookie(cookie)
+ if cookie then
+ local ok
+ ok, err = self.strategy.open(self, cookie, keep_lock)
+ if ok then
+ return self, true
+ end
+ end
+ end
+
+ regenerate(self, true)
+
+ return self, false, err
+end
+
+function session.start(opts)
+ if opts and getmetatable(opts) == session and opts.started then
+ return opts, opts.present
+ end
+
+ local self, present, reason = session.open(opts, true)
+
+ self.started = true
+
+ if not present then
+ local ok, err = save(self)
+ if not ok then
+ return nil, err or "unable to save session cookie"
+ end
+
+ return self, present, reason
+ end
+
+ if self.strategy.start then
+ local ok, err = self.strategy.start(self)
+ if not ok then
+ return nil, err or "unable to start session"
+ end
+ end
+
+ if self.expires - self.now < self.cookie.renew
+ or self.expires > self.now + self.cookie.lifetime
+ then
+ local ok, err = save(self)
+ if not ok then
+ return nil, err or "unable to save session cookie"
+ end
+ else
+ -- we're not saving, so we must touch to update idletime/usebefore
+ local ok, err = touch(self)
+ if not ok then
+ return nil, err or "unable to touch session cookie"
+ end
+ end
+
+ return self, true
+end
+
+function session.destroy(opts)
+ if opts and getmetatable(opts) == session and opts.destroyed then
+ return true
+ end
+
+ local self, err = session.start(opts)
+ if not self then
+ return nil, err
+ end
+
+ if self.strategy.destroy then
+ self.strategy.destroy(self)
+ elseif self.strategy.close then
+ self.strategy.close(self)
+ end
+
+ self.data = {}
+ self.present = nil
+ self.opened = nil
+ self.started = nil
+ self.closed = true
+ self.destroyed = true
+
+ return set_cookie(self, "", true)
+end
+
+function session:regenerate(flush, close)
+ close = close ~= false
+ if self.strategy.regenerate then
+ if flush then
+ self.data = {}
+ end
+
+ if not self.id then
+ self.id = session:identifier()
+ end
+ else
+ regenerate(self, flush)
+ end
+
+ return save(self, close)
+end
+
+function session:save(close)
+ close = close ~= false
+
+ if not self.id then
+ self.id = self:identifier()
+ end
+
+ return save(self, close)
+end
+
+function session:close()
+ self.closed = true
+
+ if self.strategy.close then
+ return self.strategy.close(self)
+ end
+
+ return true
+end
+
+function session:hide()
+ local cookies = var.http_cookie
+ if not cookies or cookies == "" then
+ return
+ end
+
+ local results = {}
+ local name = self.name
+ local name_len = #name
+ local found
+ local i = 1
+ local j = 0
+ local sc_pos = find(cookies, ";", i, true)
+ while sc_pos do
+ local isc, cookie = is_session_cookie(sub(cookies, i, sc_pos - 1), name, name_len)
+ if isc then
+ found = true
+ elseif cookie then
+ j = j + 1
+ results[j] = cookie
+ end
+
+ i = sc_pos + 1
+ sc_pos = find(cookies, ";", i, true)
+ end
+
+ local isc, cookie
+ if i == 1 then
+ isc, cookie = is_session_cookie(cookies, name, name_len)
+ else
+ isc, cookie = is_session_cookie(sub(cookies, i), name, name_len)
+ end
+
+ if not isc and cookie then
+ if not found then
+ return
+ end
+
+ j = j + 1
+ results[j] = cookie
+ end
+
+ if j == 0 then
+ clear_header("Cookie")
+ else
+ set_header("Cookie", concat(results, "; ", 1, j))
+ end
+end
+
+return session