-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local cjson = require 'cjson'

-- User resource tree
-- Object paths of the individual TLS cipher-suite resources. The trailing
-- number of each path is the cipher-suite index; only the indices listed here
-- are referenced by this file (the numbering is not contiguous).
local TLS_CIPHERSUITS_PATH_0<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/0'
local TLS_CIPHERSUITS_PATH_1<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/1'
local TLS_CIPHERSUITS_PATH_3<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/3'
local TLS_CIPHERSUITS_PATH_4<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/4'
local TLS_CIPHERSUITS_PATH_5<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/5'
local TLS_CIPHERSUITS_PATH_7<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/7'
local TLS_CIPHERSUITS_PATH_11<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/11'
local TLS_CIPHERSUITS_PATH_13<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/13'
local TLS_CIPHERSUITS_PATH_14<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/14'
local TLS_CIPHERSUITS_PATH_15<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/15'
local TLS_CIPHERSUITS_PATH_18<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/18'
local TLS_CIPHERSUITS_PATH_19<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/19'
local TLS_CIPHERSUITS_PATH_20<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/20'
local TLS_CIPHERSUITS_PATH_21<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/21'
local TLS_CIPHERSUITS_PATH_22<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/22'
local TLS_CIPHERSUITS_PATH_23<const> = '/bmc/kepler/Managers/1/Security/TlsConfig/CipherSuits/23'
-- Interface name passed to mdb.get_object together with one of the paths above.
local TLS_CIPHERSUITS_MDB_INTF<const> = 'bmc.kepler.Managers.Security.TlsConfig.CipherSuit'
-- Path and interface of the object whose SSLCertAlgorithm property is read by
-- the signature-algorithm getters in this file.
local SSL_SIGNATURE_ALGORITHMS_PATH<const> = '/bmc/kepler/CertificateService'
local SSL_SIGNATURE_ALGORITHMS_INTF<const> = 'bmc.kepler.CertificateService'

-- Module table; every public function is attached to it and it is returned at
-- the end of the file.
local m = {}

-- Maps each cipher-suite name to the object path that is queried (through
-- mdb.get_object) to read that suite's Enabled property.
local ssl_chipersuites_mdb_table = {
    ["ECDHE_RSA_AES256_GCM_SHA384"]     = TLS_CIPHERSUITS_PATH_0,
    ["ECDHE_ECDSA_AES256_GCM_SHA384"] = TLS_CIPHERSUITS_PATH_1,
    ["DHE_RSA_AES256_GCM_SHA384"] = TLS_CIPHERSUITS_PATH_3,
    ["ECDHE_RSA_AES128_GCM_SHA256"] = TLS_CIPHERSUITS_PATH_4,
    ["ECDHE_ECDSA_AES128_GCM_SHA256"] = TLS_CIPHERSUITS_PATH_5,
    ["DHE_RSA_AES128_GCM_SHA256"] = TLS_CIPHERSUITS_PATH_7,
    ["ECDHE_RSA_CHACHA20_POLY1305"] = TLS_CIPHERSUITS_PATH_11,
    ["DHE_RSA_AES128_CCM"] = TLS_CIPHERSUITS_PATH_13,
    ["DHE_RSA_AES256_CCM"] = TLS_CIPHERSUITS_PATH_14,
    ["DHE_RSA_CHACHA20_POLY1305"] = TLS_CIPHERSUITS_PATH_15,
    ["ECDHE_ECDSA_AES128_CCM"] = TLS_CIPHERSUITS_PATH_18,
    ["ECDHE_ECDSA_AES256_CCM"] = TLS_CIPHERSUITS_PATH_19,
    ["ECDHE_ECDSA_CHACHA20_POLY1305"] = TLS_CIPHERSUITS_PATH_20,
    ["TLS_CHACHA20_POLY1305_SHA256"] = TLS_CIPHERSUITS_PATH_21,
    ["TLS_AES_256_GCM_SHA384"] = TLS_CIPHERSUITS_PATH_22,
    ["TLS_AES_128_GCM_SHA256"] = TLS_CIPHERSUITS_PATH_23
}

--- List the names of all cipher suites that are currently enabled.
-- Every entry of ssl_chipersuites_mdb_table is looked up through
-- mdb.get_object and kept when the object's Enabled property is truthy.
-- The order of the returned names follows pairs() and is therefore unspecified.
-- @return sequence of cipher-suite names (possibly empty)
function m.get_enable_tls_cipher_suits()
    local enabled_names = {}
    for suite_name, object_path in pairs(ssl_chipersuites_mdb_table) do
        local suite = mdb.get_object(bus, object_path, TLS_CIPHERSUITS_MDB_INTF)
        if suite.Enabled then
            enabled_names[#enabled_names + 1] = suite_name
        end
    end
    return enabled_names
end

--- List the names of all cipher suites that are currently disabled.
-- Every entry of ssl_chipersuites_mdb_table is looked up through
-- mdb.get_object and kept when the object's Enabled property is falsy.
-- The order of the returned names follows pairs() and is therefore unspecified.
-- @return sequence of cipher-suite names, or {"NONE"} when nothing is disabled
function m.get_disable_tls_cipher_suits()
    local disabled_names = {}
    for suite_name, object_path in pairs(ssl_chipersuites_mdb_table) do
        local suite = mdb.get_object(bus, object_path, TLS_CIPHERSUITS_MDB_INTF)
        if not suite.Enabled then
            disabled_names[#disabled_names + 1] = suite_name
        end
    end
    -- An empty result is reported with the explicit "NONE" marker.
    if #disabled_names > 0 then
        return disabled_names
    end
    return {"NONE"}
end

--- Report the signature algorithm that is currently allowed.
-- Reads SSLCertAlgorithm from the certificate-service object: the value 0
-- yields 'RSA', any other value yields 'ECC'.
-- @return one-element sequence: {'RSA'} or {'ECC'}
function m.get_allowed_ssl_signature_algorithms()
    local cert_service = mdb.get_object(bus, SSL_SIGNATURE_ALGORITHMS_PATH, SSL_SIGNATURE_ALGORITHMS_INTF)
    -- 'RSA' is truthy, so the and/or idiom is safe here.
    local algorithm = cert_service.SSLCertAlgorithm == 0 and 'RSA' or 'ECC'
    return {algorithm}
end

--- Report the signature algorithm that is currently denied.
-- This is the complement of get_allowed_ssl_signature_algorithms: when
-- SSLCertAlgorithm is 0 the denied algorithm is 'ECC', otherwise it is 'RSA'.
-- Exactly one algorithm is always returned, so the former "NONE" fallback
-- (taken only for an empty result) could never run and has been removed.
-- @return one-element sequence: {'ECC'} or {'RSA'}
function m.get_denied_ssl_signature_algorithms()
    local obj = mdb.get_object(bus, SSL_SIGNATURE_ALGORITHMS_PATH, SSL_SIGNATURE_ALGORITHMS_INTF)
    if obj.SSLCertAlgorithm == 0 then
        return {'ECC'}
    end
    return {'RSA'}
end

--- Drop from `allowed` every value that also appears in `denied`.
-- A new list is built instead of editing `allowed` in place, for two reasons:
--   * removing elements while traversing the same table shifts the remaining
--     elements, so the entry following each removed one was skipped;
--   * `allowed` may be a table shared with the caller (for example a
--     module-level "all values" list) and must not be shrunk permanently.
-- @param allowed sequence of values
-- @param denied  table whose values are the denied entries
-- @return new sequence holding the entries of `allowed` that are not denied,
--         in their original order
-- @return `denied`, unchanged
local function split_array(allowed, denied)
    local is_denied = {}
    for _, v in pairs(denied) do
        is_denied[v] = true
    end

    local kept = {}
    for i = 1, #allowed do
        local v = allowed[i]
        if not is_denied[v] then
            kept[#kept + 1] = v
        end
    end
    return kept, denied
end

--- Tell whether a list stands for "everything".
-- That is the case for an empty list and for the single-element list {"ALL"}.
-- @param arr sequence to inspect
-- @return boolean
local function is_all_array(arr)
    local count = #arr
    return count == 0 or (count == 1 and arr[1] == "ALL")
end

--- Tell whether a list stands for "nothing", i.e. is exactly {"NONE"}.
-- @param arr sequence to inspect
-- @return boolean
local function is_array_null(arr)
    return #arr == 1 and arr[1] == "NONE"
end

--- Normalize the allowed/denied lists and remove denied entries from allowed.
-- Normalization rules, applied to each list independently:
--   * "ALL" or an empty list is expanded to the full set `all_allowed`;
--   * "NONE" becomes an empty list, meaning nothing is configured.
-- The full set is always expanded into a private copy: `all_allowed` is
-- usually a table shared across calls, and the result is handed to
-- split_array and then to the caller, either of which may modify it.
-- @param allowed     sequence of allowed values, or {"ALL"} / {"NONE"} / {}
-- @param denied      sequence of denied values, or {"ALL"} / {"NONE"} / {}
-- @param all_allowed sequence holding the full set of valid values
-- @return normalized allowed list without the denied entries
-- @return normalized denied list
local function split_allowed_and_denied(allowed, denied, all_allowed)
    local function clone_full_set()
        local copy = {}
        for i = 1, #all_allowed do
            copy[i] = all_allowed[i]
        end
        return copy
    end

    if is_all_array(allowed) then
        allowed = clone_full_set()
    elseif is_array_null(allowed) then
        allowed = {}
    end

    if is_all_array(denied) then
        denied = clone_full_set()
    elseif is_array_null(denied) then
        denied = {}
    end

    return split_array(allowed, denied)
end
-- Maps a TLS protocol name (the TlsName field of a "supported" entry) to the
-- bare version string used by the allowed/denied version lists.
local tls_version_map = {
    ['TLS1.0'] = '1.0',
    ['TLS1.1'] = '1.1',
    ['TLS1.2'] = '1.2',
    ['TLS1.3'] = '1.3',
}

--- Extract the version strings of all supported TLS protocols.
-- @param supported table of entries, each carrying a TlsName and an
--        IsSupported field
-- @return sequence of version strings (looked up in tls_version_map) for the
--         entries whose IsSupported is truthy; the order is unspecified
local function get_supported_tls(supported)
    -- Collapse the entries into name -> flag first, so that a repeated name
    -- contributes at most one result.
    local support_by_name = {}
    for _, entry in pairs(supported) do
        support_by_name[entry.TlsName] = entry.IsSupported
    end

    local versions = {}
    for tls_name, is_supported in pairs(support_by_name) do
        if is_supported then
            versions[#versions + 1] = tls_version_map[tls_name]
        end
    end
    return versions
end

--- Verify that every requested value belongs to the supported set.
-- The single-element keyword lists {"ALL"} and {"NONE"} are accepted as is.
-- @param input_table sequence of requested values
-- @param supported   table whose values form the supported set
-- Raises base_messages.PropertyValueNotInList for the first value found that
-- is not supported.
local function check_input_supported(input_table, supported)
    local is_keyword = #input_table == 1
        and (input_table[1] == "ALL" or input_table[1] == "NONE")
    if is_keyword then
        return
    end

    local supported_set = {}
    for _, name in pairs(supported) do
        supported_set[name] = true
    end

    for _, candidate in pairs(input_table) do
        if not supported_set[candidate] then
            error(base_messages.PropertyValueNotInList(candidate, 'Allowed or Denied Versions'))
        end
    end
end

--- Convert allowed/denied TLS version lists into four enable flags.
-- @param allowed   sequence of allowed versions, or {"ALL"} / {}
-- @param denied    sequence of denied versions, or {"NONE"}
-- @param supported table of entries with TlsName / IsSupported fields
-- @return sequence of four booleans: the enable flags for versions
--         "1.0", "1.1", "1.2" and "1.3", in that order
-- Raises base_messages.PropertyValueNotInList when allowed is "NONE", when
-- denied is "ALL" or empty, or when a listed version is not supported.
function m.set_tls_allowed_denied_version(allowed, denied, supported)
    -- Basic functional requirement: the set of allowed versions must not end
    -- up empty.
    if is_array_null(allowed) or is_all_array(denied) then
        log:error("allowed version being none or denied version being all is denied")
        error(base_messages.PropertyValueNotInList('NONE or ALL', 'Allowed or Denied Versions'))
    end

    local supported_versions = get_supported_tls(supported)
    -- Everything listed in allowed and in denied must be a supported version.
    check_input_supported(allowed, supported_versions)
    check_input_supported(denied, supported_versions)
    local allow_list, deny_list = split_allowed_and_denied(allowed, denied, supported_versions)

    -- Apply allowed first and denied second, so that denied wins on overlap.
    local state = {}
    for _, version in pairs(allow_list) do
        state[version] = true
    end
    for _, version in pairs(deny_list) do
        state[version] = false
    end

    -- Versions that were never mentioned default to false.
    local flags = {}
    for position, version in ipairs({"1.0", "1.1", "1.2", "1.3"}) do
        flags[position] = state[version] == true
    end
    return flags
end

-- Full list of cipher-suite names handled by this module. It is the set that
-- "ALL" (or an empty list) expands to when cipher suites are configured.
local cipher_suites_all = {
    "ECDHE_RSA_AES256_GCM_SHA384",
    "ECDHE_ECDSA_AES256_GCM_SHA384",
    "DHE_RSA_AES256_GCM_SHA384",
    "ECDHE_RSA_AES128_GCM_SHA256",
    "ECDHE_ECDSA_AES128_GCM_SHA256",
    "DHE_RSA_AES128_GCM_SHA256",
    "ECDHE_RSA_CHACHA20_POLY1305",
    "DHE_RSA_AES128_CCM",
    "DHE_RSA_AES256_CCM",
    "DHE_RSA_CHACHA20_POLY1305",
    "ECDHE_ECDSA_AES128_CCM",
    "ECDHE_ECDSA_AES256_CCM",
    "ECDHE_ECDSA_CHACHA20_POLY1305",
    "TLS_CHACHA20_POLY1305_SHA256",
    "TLS_AES_256_GCM_SHA384",
    "TLS_AES_128_GCM_SHA256"
}

-- Maps each cipher-suite name to its numeric index. The numbers are the same
-- as the trailing path component of the corresponding CipherSuits object path
-- declared at the top of the file.
local ssl_chipersuites_mdb_table_num = {
    ["ECDHE_RSA_AES256_GCM_SHA384"]     = 0,
    ["ECDHE_ECDSA_AES256_GCM_SHA384"] = 1,
    ["DHE_RSA_AES256_GCM_SHA384"] = 3,
    ["ECDHE_RSA_AES128_GCM_SHA256"] = 4,
    ["ECDHE_ECDSA_AES128_GCM_SHA256"] = 5,
    ["DHE_RSA_AES128_GCM_SHA256"] = 7,
    ["ECDHE_RSA_CHACHA20_POLY1305"] = 11,
    ["DHE_RSA_AES128_CCM"] = 13,
    ["DHE_RSA_AES256_CCM"] = 14,
    ["DHE_RSA_CHACHA20_POLY1305"] = 15,
    ["ECDHE_ECDSA_AES128_CCM"] = 18,
    ["ECDHE_ECDSA_AES256_CCM"] = 19,
    ["ECDHE_ECDSA_CHACHA20_POLY1305"] = 20,
    ["TLS_CHACHA20_POLY1305_SHA256"] = 21,
    ["TLS_AES_256_GCM_SHA384"] = 22,
    ["TLS_AES_128_GCM_SHA256"] = 23
}

--- Convert allowed/denied cipher-suite lists into two parallel JSON arrays.
-- @param allowed sequence of cipher-suite names to enable, or {"ALL"} / {}
-- @param denied  sequence of cipher-suite names to disable, or {"NONE"}
-- @return JSON object with the fields
--         suits   - array of cipher-suite indices
--         enabled - array of booleans; entry i is the new state of suits[i]
-- Raises base_messages.PropertyValueNotInList when allowed is "NONE", when
-- denied is "ALL" or empty, or when a listed name is not a known cipher
-- suite.
function m.set_tls_allowed_denied_cipherSuites(allowed, denied)
    -- Basic functional requirement: the set of allowed cipher suites must not
    -- end up empty.
    if is_array_null(allowed) or is_all_array (denied) then
        log:error("allowed CipherSuites being none or denied CipherSuites being all is denied")
        error(base_messages.PropertyValueNotInList('NONE or ALL', 'Allowed or Denied CipherSuites'))
    end
    local allowed_res, denied_res = split_allowed_and_denied(allowed, denied, cipher_suites_all)
    local res = cjson.json_object_new_object()
    res.suits = cjson.json_object_new_array()
    res.enabled = cjson.json_object_new_array()

    -- Append one (index, state) pair per name. An unknown name has no index;
    -- storing nil for it while still appending its state would shift every
    -- following state onto the wrong suite, so such input is rejected instead.
    local function append_all(names, enabled)
        for _, name in pairs(names) do
            local suite_index = ssl_chipersuites_mdb_table_num[name]
            if suite_index == nil then
                log:error("unsupported CipherSuite in allowed or denied CipherSuites")
                error(base_messages.PropertyValueNotInList(name, 'Allowed or Denied CipherSuites'))
            end
            res.suits[#res.suits + 1] = suite_index
            res.enabled[#res.enabled + 1] = enabled
        end
    end

    append_all(allowed_res, true)
    append_all(denied_res, false)
    return res
end

-- Full list of signature algorithms handled by this module. It is the set
-- that "ALL" (or an empty list) expands to when signature algorithms are
-- configured.
local ssl_signature_all = {
    'RSA',
    'ECC'
}

--- Pick the single signature algorithm to enable from allowed/denied lists.
-- The result uses the same encoding as the SSLCertAlgorithm property read by
-- the getters of this module: 0 selects RSA, 1 selects ECC.
--
-- Selection rules after normalization:
--   * exactly one allowed algorithm remains -> that algorithm is selected;
--   * no allowed algorithm remains and all but one algorithm are denied ->
--     the algorithm that is not denied is selected.
-- The second rule used to be unreachable (its condition was the negation of
-- a check that had already raised) and returned the swapped code; an unknown
-- allowed value silently produced nil. Every path now returns 0 or 1, or
-- raises.
-- @param allowed sequence of allowed algorithms, or {"ALL"} / {}
-- @param denied  sequence of denied algorithms, or {"NONE"}
-- @return 0 for RSA, 1 for ECC
-- Raises base_messages.PropertyValueNotInList when no single algorithm can be
-- selected or when the allowed algorithm is unknown.
function m.set_tls_allowed_denied_signatureAlgorithms(allowed, denied)
    local algorithm_ids = {RSA = 0, ECC = 1}

    -- Log and raise; never returns.
    local function reject(value)
        log:error("allowed SignatureAlogrithms being none or denied SignatureAlogrithms being all is denied")
        error(base_messages.PropertyValueNotInList(value, 'Allowed or Denied SignatureAlogrithms'))
    end

    -- Basic functional requirement: the set of allowed algorithms must not
    -- end up empty.
    if is_array_null(allowed) or is_all_array (denied) then
        reject('NONE or ALL')
    end
    local allowed_res, denied_res = split_allowed_and_denied(allowed, denied, ssl_signature_all)

    -- The algorithm to enable has to be derived from the two result lists.
    if #allowed_res > 1 then
        -- More than one candidate: there is no unique choice.
        reject('NONE or ALL')
    end
    if #allowed_res == 1 then
        local id = algorithm_ids[allowed_res[1]]
        if id == nil then
            reject(allowed_res[1])
        end
        return id
    end

    -- Nothing is explicitly allowed: fall back to "the one that is not
    -- denied", which requires all but one algorithm to be denied.
    if #denied_res ~= #ssl_signature_all - 1 then
        reject('NONE or ALL')
    end
    local is_denied = {}
    for _, v in pairs(denied_res) do
        is_denied[v] = true
    end
    local remaining = {}
    for _, name in ipairs(ssl_signature_all) do
        if not is_denied[name] then
            remaining[#remaining + 1] = name
        end
    end
    if #remaining ~= 1 then
        -- The denied list named something that is not a known algorithm.
        reject('NONE or ALL')
    end
    return algorithm_ids[remaining[1]]
end

return m