aboutsummaryrefslogtreecommitdiffstats
path: root/server/resty/session/ciphers/aes.lua
blob: 9a088addf5b504d5528ed5d1c5121d1e454caaca (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
local aes          = require "resty.aes"

local setmetatable = setmetatable
local tonumber     = tonumber
local ceil         = math.ceil
local var          = ngx.var
local sub          = string.sub
local rep          = string.rep

-- Digest functions exported by resty.aes, indexed by name (e.g. "sha512").
local HASHES = aes.hash

-- Supported AES modes of operation. Each name maps to itself, so indexing
-- this table with a configured string both validates and normalizes it:
-- unrecognized names simply look up nil.
local CIPHER_MODES = {}
for _, name in ipairs({ "ecb", "cbc", "cfb1", "cfb8", "cfb128", "ofb", "ctr", "gcm" }) do
    CIPHER_MODES[name] = name
end

-- Supported AES key sizes in bits, self-keyed for the same
-- lookup-as-validation purpose.
local CIPHER_SIZES = {}
for _, bits in ipairs({ 128, 192, 256 }) do
    CIPHER_SIZES[bits] = bits
end

-- Module-wide fallbacks, evaluated once when this module body runs. Each one
-- comes from an nginx variable ($session_aes_size, $session_aes_mode,
-- $session_aes_hash, $session_aes_rounds) when it holds a recognized value,
-- otherwise from the literal shown.
-- NOTE(review): relies on tonumber(nil, 10) returning nil (LuaJIT / Lua 5.1
-- behaviour) so that unset variables fall through to the literals.
local defaults = {}
defaults.size   = CIPHER_SIZES[tonumber(var.session_aes_size, 10)] or 256
defaults.mode   = CIPHER_MODES[var.session_aes_mode]               or "cbc"
defaults.hash   = HASHES[var.session_aes_hash]                     or HASHES.sha512
defaults.rounds = tonumber(var.session_aes_rounds, 10)             or 1

--- Normalize a salt to exactly 8 bytes.
-- Salts shorter than 8 bytes are repeated cyclically and cut to 8 bytes;
-- longer salts are truncated to their first 8 bytes; 8-byte salts are
-- returned as-is.
-- @tparam[opt] string salt
-- @treturn ?string nil when no salt is given, the empty string unchanged when
--   salt is "", otherwise an 8-byte string
local function adjust_salt(salt)
    if not salt then
        return nil
    end

    local z = #salt

    -- An empty salt cannot be stretched: repeating it would take
    -- ceil(8 / 0) == math.huge copies, which string.rep does not accept
    -- portably (Lua 5.3+ raises; LuaJIT relies on an inf->int cast).
    -- Pass it through and let the cipher constructor decide what to do.
    if z == 0 then
        return salt
    end

    if z < 8 then
        salt = rep(salt, ceil(8 / z))
    end

    -- Truncates long (or just-repeated) salts; a no-op for 8-byte salts.
    return sub(salt, 1, 8)
end

--- Build a resty.aes instance matching a cipher's configuration.
-- @param self cipher instance (reads size, mode, hash, rounds)
-- @param key secret key material
-- @param[opt] salt salt, normalized to 8 bytes via adjust_salt
-- @return the resty.aes instance, or nil plus an error message
local function get_cipher(self, key, salt)
    local size, mode_name = self.size, self.mode

    local evp_cipher = aes.cipher(size, mode_name)
    if not evp_cipher then
        local label = mode_name .. "(" .. size .. ")"
        return nil, "invalid cipher mode " .. label
    end

    local normalized_salt = adjust_salt(salt)
    return aes:new(key, normalized_salt, evp_cipher, self.hash, self.rounds)
end

-- Prototype shared by every AES cipher instance.
local cipher = {}
cipher.__index = cipher

--- Create an AES cipher for a session.
-- Reads the optional `session.aes` table (size, mode, hash, rounds). Any
-- field that is missing or not recognized falls back to the module-wide
-- defaults, and when `session.aes` is absent the defaults are used wholesale.
-- @param session session object
-- @return cipher instance
function cipher.new(session)
    local config = session.aes or defaults

    -- Raw lookups first; each yields nil when the configured value is invalid.
    local size   = CIPHER_SIZES[tonumber(config.size, 10)]
    local mode   = CIPHER_MODES[config.mode]
    local hash   = HASHES[config.hash]
    local rounds = tonumber(config.rounds, 10)

    local instance = {
        size   = size   or defaults.size,
        mode   = mode   or defaults.mode,
        hash   = hash   or defaults.hash,
        rounds = rounds or defaults.rounds,
    }
    return setmetatable(instance, cipher)
end

--- Encrypt data with the configured AES cipher.
-- @param data plaintext string
-- @param key secret key material
-- @param[opt] salt salt (normalized to 8 bytes)
-- @param _ unused
-- @return ciphertext; in gcm mode: ciphertext, nil, authentication tag
-- @return nil, error message on failure
function cipher:encrypt(data, key, salt, _)
    local engine, err = get_cipher(self, key, salt)
    if not engine then
        return nil, err or "unable to aes encrypt data"
    end

    local output
    output, err = engine:encrypt(data)
    if not output then
        return nil, err or "aes encryption failed"
    end

    if self.mode ~= "gcm" then
        return output
    end

    -- In gcm mode the result is indexed as { ciphertext, tag }.
    return output[1], nil, output[2]
end

--- Decrypt data with the configured AES cipher.
-- @param data ciphertext string
-- @param key secret key material
-- @param[opt] salt salt (normalized to 8 bytes)
-- @param _ unused
-- @param[opt] tag authentication tag, forwarded to resty.aes (used in gcm mode)
-- @return plaintext; in gcm mode: plaintext, nil, tag
-- @return nil, error message on failure
function cipher:decrypt(data, key, salt, _, tag)
    local engine, err = get_cipher(self, key, salt)
    if not engine then
        return nil, err or "unable to aes decrypt data"
    end

    local plaintext
    plaintext, err = engine:decrypt(data, tag)
    if not plaintext then
        return nil, err or "aes decryption failed"
    end

    if self.mode ~= "gcm" then
        return plaintext
    end

    -- gcm callers get the tag echoed back alongside the plaintext.
    return plaintext, nil, tag
end

return cipher