-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.
local log = require 'mc.logging'
local context = require 'mc.context'
local cjson = require 'cjson'
local mdb_service = require 'mc.mdb.mdb_service'
local file_utils = require 'utils.file'
local core_utils = require 'utils.core'
local vos = require 'utils.vos'

-- Resource-tree (MDB) paths, service identifiers and fixed file locations used by this module.
-- NOTE(review): CERTS_MDB_PATH is not referenced anywhere in this file.
local CERTS_MDB_PATH<const> = '/bmc/kepler/Managers/%d/Certificates'
-- CA certificate object path; formatted with (manager id, certificate id).
local CERT_MDB_PATH_FMT<const> = '/bmc/kepler/Managers/%d/Certificates/%d'
-- SSL (HTTPS) certificate object path; formatted with (manager id, certificate id).
local SSL_CERT_MDB_PATH_FMT<const> = '/bmc/kepler/Managers/%d/NetworkProtocol/HTTPS/Certificates/%d'
-- Certificate service object on which GetCertChainInfo is called.
local CERT_SERVICE_PATH<const> = '/bmc/kepler/CertificateService'
local CERT_SERVICE_INTF<const> = 'bmc.kepler.CertificateService'
local CERT_CSR<const> = '/data/trust/cert/server.csr'  -- fixed path where the generated SSL certificate CSR is exported
local CSR_KEY_PAIR<const> = '/data/trust/cert/csr_key.pem'  -- public/private key pair belonging to the CSR
-- Redfish @odata.id patterns of importable certificates; capture 1 is the certificate id.
local CA_ODATA_ID<const> = '^/redfish/v1/Managers/1/Certificates/(%d+)$'
local SSL_ODATA_ID<const> = '^/redfish/v1/Managers/1/NetworkProtocol/HTTPS/Certificates/(%d+)$'

-- Module table; returned at the end of the file.
local m = {}

---Return the Redfish references of all CA certificates of the BMC, sorted by Id in ascending order.
---@param cert_list string[] resource-tree paths such as '/bmc/kepler/Managers/1/Certificates/3'
---@return table[] list of {['@odata.id'] = '/redfish/v1/...'}; paths without a numeric id are skipped
function m.get_ca_certificate_paths(cert_list)
    -- Deduplicate by certificate id (the last path seen for an id wins) and remember each id once.
    local by_id = {}
    local ids = {}
    for _, path in pairs(cert_list) do
        local id_str = string.match(path, 'Certificates/(%d+)')
        local cert_id = id_str and tonumber(id_str)
        if cert_id then
            if by_id[cert_id] == nil then
                ids[#ids + 1] = cert_id
            end
            -- parentheses keep only the first result of gsub (the substituted string)
            by_id[cert_id] = {['@odata.id'] = (string.gsub(path, '/bmc/kepler', '/redfish/v1'))}
        end
    end
    -- pairs() has no ordering guarantee, so sort explicitly to deliver the documented order
    table.sort(ids)
    local paths = {}
    for i, cert_id in ipairs(ids) do
        paths[i] = by_id[cert_id]
    end
    return paths
end

---Return the Redfish references of all SSL certificates of the BMC, sorted by Id in ascending order.
---@param cert_list string[] resource-tree paths such as
---       '/bmc/kepler/Managers/1/NetworkProtocol/HTTPS/Certificates/1'
---@return table[] list of {['@odata.id'] = '/redfish/v1/...'}; paths without a numeric id are skipped
function m.get_ssl_certificate_paths(cert_list)
    -- Deduplicate by certificate id (the last path seen for an id wins) and remember each id once.
    local by_id = {}
    local ids = {}
    for _, path in pairs(cert_list) do
        local id_str = string.match(path, 'Certificates/(%d+)')
        local cert_id = id_str and tonumber(id_str)
        if cert_id then
            if by_id[cert_id] == nil then
                ids[#ids + 1] = cert_id
            end
            -- parentheses keep only the first result of gsub (the substituted string)
            by_id[cert_id] = {['@odata.id'] = (string.gsub(path, '/bmc/kepler', '/redfish/v1'))}
        end
    end
    -- pairs() has no ordering guarantee, so sort explicitly to deliver the documented order
    table.sort(ids)
    local paths = {}
    for i, cert_id in ipairs(ids) do
        paths[i] = by_id[cert_id]
    end
    return paths
end

--- Check whether `id` refers to an existing CA certificate object.
---@param manager_id integer manager id used to build the resource-tree path
---@param id any certificate id; must be an integer (or integer string) in the range 1-32
---@return boolean true when the object exists; false (with an error logged) otherwise
function m.is_valid_cert_path(manager_id, id)
    local id_number = tonumber(id)
    -- math.tointeger also rejects fractional/infinite values, which '%d' in the path format cannot print
    local id_int = id_number and math.tointeger(id_number)
    if not id_int then
        log:error('the id(%s) of cert is invalid', id)
        return false
    end
    if id_int < 1 or id_int > 32 then
        log:error('the id(%u) of cert is invalid', id_int)
        return false
    end
    local path = string.format(CERT_MDB_PATH_FMT, manager_id, id_int)
    local ok, rsp = pcall(mdb_service.is_valid_path, bus, path)
    if not ok or not rsp or not rsp.Result then
        log:error('the id(%u) of cert is invalid', id_int)
        return false
    end
    return true
end

--- Check whether `id` refers to an existing SSL (HTTPS) certificate object.
---@param manager_id integer manager id used to build the resource-tree path
---@param id any certificate id; must be an integer or an integer string
---@return boolean true when the object exists; false (with an error logged) otherwise
function m.is_valid_ssl_cert_path(manager_id, id)
    local id_number = tonumber(id)
    -- math.tointeger also rejects fractional/infinite values, which '%d' in the path format cannot print
    local id_int = id_number and math.tointeger(id_number)
    if not id_int then
        log:error('the id(%s) of cert is invalid', id)
        return false
    end
    local path = string.format(SSL_CERT_MDB_PATH_FMT, manager_id, id_int)
    local ok, rsp = pcall(mdb_service.is_valid_path, bus, path)
    if not ok or not rsp or not rsp.Result then
        log:error('the id(%u) of ssl cert is invalid', id_int)
        return false
    end
    return true
end

-- Display names of the X.509 KeyUsage identifiers used by the legacy comma-separated format.
local KEY_USAGE_DISPLAY_NAMES<const> = {
    DigitalSignature = 'Digital Signature',
    NonRepudiation = 'Non Repudiation',
    KeyEncipherment = 'Key Encipherment',
    DataEncipherment = 'Data Encipherment',
    KeyAgreement = 'Key Agreement',
    KeyCertSign = 'Certificate Sign',
    CRLSigning = 'CRL Sign',
    EncipherOnly = 'Encipher Only',
    DecipherOnly = 'Decipher Only'
}

-- Convert key-usage identifiers of a certificate into a comma-separated display string.
-- @param key_usage_str table|string|nil  list of identifiers, or one string of identifiers separated
--        by commas and/or spaces (e.g. 'DigitalSignature, KeyEncipherment')
-- @return string e.g. 'Digital Signature, Key Encipherment'; '' when there is nothing to convert.
--         Identifiers without a display name are emitted unchanged instead of raising an error.
local function get_key_usage(key_usage_str)
    local items
    if type(key_usage_str) == 'table' then
        items = key_usage_str
    else
        items = {}
        if type(key_usage_str) == 'string' then
            -- every ',' and ' ' is a separator, so runs of separators produce no empty items
            for token in string.gmatch(key_usage_str, '[^, ]+') do
                items[#items + 1] = token
            end
        end
    end
    local labels = {}
    for _, item in ipairs(items) do
        labels[#labels + 1] = KEY_USAGE_DISPLAY_NAMES[item] or tostring(item)
    end
    return table.concat(labels, ', ')
end

-- Build the X509CertificateInformation data of a certificate (chain) queried from the certificate service.
---@param cert_usage_type integer 0 for a CA certificate, 1 for an SSL certificate
---@param cert_id integer CA ids range from 1 to 32; the SSL id can only be 1
---@return table ordered JSON object with ServerCert and, when present, IntermediateCert and RootCert
function m.get_certificate_information(cert_usage_type, cert_id)
    local usage_type = tonumber(cert_usage_type)
    local id = tonumber(cert_id)
    local cert_service_obj = mdb.get_object(bus, CERT_SERVICE_PATH, CERT_SERVICE_INTF)
    local rsp = cert_service_obj:GetCertChainInfo(context.new(), usage_type, id)
    rsp = cjson.decode(rsp.CertInfo)

    -- Copy one certificate record into an ordered JSON object. cjson's ordered object keeps
    -- insertion order, so the assignment order below is the field order of the output.
    local function to_cert_info(src)
        local info = cjson.json_object_new_object()
        info.Subject = src.Subject
        info.Issuer = src.Issuer
        info.ValidNotBefore = src.ValidNotBefore
        info.ValidNotAfter = src.ValidNotAfter
        info.SerialNumber = src.SerialNumber
        info.SignatureAlgorithm = src.FingerprintHashAlgorithm
        info.KeyUsage = get_key_usage(src.KeyUsage)
        info.PublicKeyLengthBits = src.KeyLength
        return info
    end

    local result = cjson.json_object_new_object()
    result.ServerCert = to_cert_info(rsp.ServerCert)
    if type(rsp.IntermediateCert) == 'table' and #rsp.IntermediateCert >= 1 then
        result.IntermediateCert = cjson.json_object_new_array()
        -- ipairs keeps the chain order; pairs() gives no ordering guarantee
        for _, intermediate_cert in ipairs(rsp.IntermediateCert) do
            result.IntermediateCert[#result.IntermediateCert + 1] = to_cert_info(intermediate_cert)
        end
    end
    if type(rsp.RootCert) == 'table' then
        result.RootCert = to_cert_info(rsp.RootCert)
    end
    return result
end

-- Source fields copied for a certificate's identity/validity, in output order, and the names they
-- are emitted under when cert_name_type == 1 (same positions).
local CERT_SOURCE_FIELDS<const> = {'Issuer', 'Subject', 'ValidNotBefore', 'ValidNotAfter'}
local CERT_LEGACY_FIELDS<const> = {'IssueBy', 'IssueTo', 'ValidFrom', 'ValidTo'}

-- Convert one decoded certificate record into an ordered JSON object.
-- @param cert_name_type   1: IssueBy/IssueTo/ValidFrom/ValidTo; otherwise Issuer/Subject/ValidNotBefore/ValidNotAfter
-- @param need_fingerprint whether to add FingerPrint
-- @param crl_print_type   1: add IsImportCrl/CrlValidFrom/CrlValidTo; any other value omits them
-- @param source_cert      decoded certificate record
-- @return ordered JSON object
local function package_cert_object(cert_name_type, need_fingerprint, crl_print_type, source_cert)
    local cert = cjson.json_object_new_object()
    local target_fields = (cert_name_type == 1) and CERT_LEGACY_FIELDS or CERT_SOURCE_FIELDS
    for index, source_field in ipairs(CERT_SOURCE_FIELDS) do
        cert[target_fields[index]] = source_cert[source_field]
    end

    cert.SerialNumber = source_cert.SerialNumber
    cert.SignatureAlgorithm = source_cert.FingerprintHashAlgorithm
    cert.KeyUsage = get_key_usage(source_cert.KeyUsage)
    cert.PublicKeyLengthBits = source_cert.KeyLength

    if need_fingerprint then
        cert.FingerPrint = source_cert.Fingerprint
    end
    if crl_print_type ~= 1 then
        return cert
    end
    -- defaults used when the record carries no CRL data
    cert.IsImportCrl = source_cert.IsImportCrl or false
    cert.CrlValidFrom = source_cert.CrlValidFrom or cjson.null
    cert.CrlValidTo = source_cert.CrlValidTo or cjson.null
    return cert
end

-- Convert a list of chain certificates (e.g. intermediates) into an ordered JSON array.
-- CRL fields are never emitted for these entries (crl_print_type is passed as false).
local function package_cert_array(cert_name_type, need_fingerprint, source_certs)
    local cert_list = cjson.json_object_new_array()
    -- ipairs preserves the chain order; pairs() gives no ordering guarantee
    for _, chain_cert in ipairs(source_certs) do
        cert_list[#cert_list + 1] = package_cert_object(cert_name_type, need_fingerprint, false, chain_cert)
    end
    return cert_list
end

---@function Build the JSON object returned for a certificate query.
---@param    is_cert_chain    boolean  true: {ServerCert, IntermediateCert?, RootCert?}; false: the server certificate only
---@param    cert_name_type   number   0: Issuer/Subject/ValidNotBefore/ValidNotAfter; 1: IssueBy/IssueTo/ValidFrom/ValidTo
---@param    need_fingerprint boolean  whether to include FingerPrint
---@param    crl_print_type   number   1: include the CRL fields in the server certificate object
---@param    rsp              table    decoded GetCertChainInfo result
local function create_cert_object(is_cert_chain, cert_name_type, need_fingerprint, crl_print_type, rsp)
    if not is_cert_chain then
        return package_cert_object(cert_name_type, need_fingerprint, crl_print_type, rsp.ServerCert)
    end

    local cert_server = cjson.json_object_new_object()
    cert_server.ServerCert = package_cert_object(cert_name_type, need_fingerprint, crl_print_type, rsp.ServerCert)
    if type(rsp.IntermediateCert) == 'table' and #rsp.IntermediateCert >= 1 then
        cert_server.IntermediateCert = package_cert_array(cert_name_type, need_fingerprint, rsp.IntermediateCert)
    end
    -- type check instead of truthiness: a JSON null decodes to cjson.null, which is truthy
    if type(rsp.RootCert) == 'table' then
        cert_server.RootCert = package_cert_object(cert_name_type, need_fingerprint, false, rsp.RootCert)
    end
    return cert_server
end

-- @function get_ca_cert_info
-- Query the information of the CA certificate with the highest id in cert_list.
-- @param cert_name_type     0: Issuer/Subject/ValidNotBefore/ValidNotAfter 1: IssueBy/IssueTo/ValidFrom/ValidTo
-- @param is_cert_chain      whether to output the certificate chain form
-- @param need_fingerprint   whether to output FingerPrint
-- @param crl_print_type     0: no CRL info 1: inside ServerCert 2: separately, as resp[3]/resp[4]
-- @param cert_list          resource-tree paths of all CA certificates
-- @return cjson.null when cert_list holds no certificate id; otherwise an array
--         {cert_object, is_import_crl, crl_valid_from|null, crl_valid_to|null}
function m.get_ca_cert_info(cert_name_type, is_cert_chain, need_fingerprint, crl_print_type, cert_list)
    local max_cert_id = 0
    for _, cert_path in ipairs(cert_list or {}) do
        -- the id is the trailing run of digits; paths without one are skipped
        local id_str = string.match(cert_path, '%d+$')
        local cert_id = id_str and tonumber(id_str)
        if cert_id and cert_id > max_cert_id then
            max_cert_id = cert_id
        end
    end
    -- no certificate installed: return null without querying the service
    if max_cert_id == 0 then
        return cjson.null
    end

    local cert_service_obj = mdb.get_object(bus, CERT_SERVICE_PATH, CERT_SERVICE_INTF)
    local rsp = cert_service_obj:GetCertChainInfo(context.new(), 0, max_cert_id)
    rsp = cjson.decode(rsp.CertInfo)
    local resp = cjson.json_object_new_array()
    -- certificate JSON structure
    resp[1] = create_cert_object(is_cert_chain, cert_name_type, need_fingerprint, crl_print_type, rsp)
    -- CRL information of the certificate, printed outside the certificate object
    resp[2] = rsp.ServerCert.IsImportCrl or false
    resp[3] = cjson.null
    resp[4] = cjson.null

    if rsp.ServerCert.IsImportCrl and crl_print_type == 2 then
        resp[3] = rsp.ServerCert.CrlValidFrom
        resp[4] = rsp.ServerCert.CrlValidTo
    end
    return resp
end

-- @function get_ca_cert_id
-- Resolve the CA certificate id to operate on.
-- @param target_cert_id requested certificate id (number or numeric string), or nil for "the last certificate"
-- @param cert_list      resource-tree paths of all CA certificates
-- @return the requested id (as a number) when it exists in cert_list; with no target, the highest id present
-- @raise custom_messages.RootCANotExists() when the requested id is absent or cert_list holds no id
function m.get_ca_cert_id(target_cert_id, cert_list)
    local wanted_id = nil
    if target_cert_id ~= nil then
        wanted_id = tonumber(target_cert_id)
        if wanted_id == nil then
            error(custom_messages.RootCANotExists())
        end
    end

    local max_cert_id = 0
    for _, cert_path in ipairs(cert_list or {}) do
        -- the id is the trailing run of digits; paths without one are skipped
        local id_str = string.match(cert_path, '%d+$')
        local cert_id = id_str and tonumber(id_str)
        if cert_id then
            if cert_id == wanted_id then
                return wanted_id
            end
            if cert_id > max_cert_id then
                max_cert_id = cert_id
            end
        end
    end
    -- requested id not found, or no certificate at all
    if wanted_id ~= nil or max_cert_id == 0 then
        error(custom_messages.RootCANotExists())
    end
    return max_cert_id
end

local utils_core = require 'utils.core'
-- Accepted certificate URIs: a remote URL (https/sftp/nfs/cifs/scp) or a local path under /tmp.
local CERT_URI_PATTERN<const> = '^((https|sftp|nfs|cifs|scp)://.{1,1000}|/tmp/.{1,246})$'

--- Validate how a certificate is supplied in a request.
-- @param type        'text' (or nil) for inline content, 'URI' for a file location
-- @param certificate the certificate content or URI
-- @param param       request property name reported in the error
-- @return true when acceptable; otherwise raises PropertyValueFormatError with RelatedProperties set
function m.check_certificate_valid(type, certificate, param)
    if not type or type == 'text' then
        return true
    end
    local uri_accepted = type == 'URI' and utils_core.g_regex_match(CERT_URI_PATTERN, certificate)
    if uri_accepted then
        return true
    end
    local format_error = base_messages.PropertyValueFormatError('******', param)
    format_error.RelatedProperties = {'#/' .. param}
    error(format_error)
end

-- Accepted import URIs per file type: a remote URL or a /tmp path with a matching file extension.
local IMPORT_URI_PATTERNS<const> = {
    ['pub'] = "^((https|sftp|nfs|cifs|scp)://.{1,1000}|/tmp/.{1,246})\\.pub$",
    ['cert'] = "^((https|sftp|nfs|cifs|scp)://.{1,1000}|/tmp/.{1,246})\\.(der|crt|cer|cert|pem|p12|pfx)$",
    ['crl'] = "^((https|sftp|nfs|cifs|scp)://.{1,1000}|/tmp/.{1,246})\\.crl$"
}

-- Name of the custom_messages factory raised when a local import file is unusable, per file type.
-- Looked up lazily so only the message actually raised is constructed.
local IMPORT_FAILED_MESSAGES<const> = {
    ['pub'] = 'PublicKeyImportFailed',
    ['cert'] = 'CertImportFailed',
    ['crl'] = 'CrlImportFailed'
}

--- Check whether a file referenced by URI may be imported.
-- @param type          'URI' triggers the checks; any other value is accepted as-is
-- @param content       the URI: a remote URL or a local path under /tmp
-- @param file_type     'pub', 'cert' or 'crl'
-- @param perproty_name request property name reported in a format error
-- @param result        caller-supplied flag; presumably the outcome of a file-permission check — confirm
-- @return true when permitted; raises otherwise
function m.is_import_permitted(type, content, file_type, perproty_name, result)
    if type ~= 'URI' then
        return true
    end

    local pattern = IMPORT_URI_PATTERNS[file_type]
    if pattern == nil then
        error(string.format('unsupported import file type: %s', tostring(file_type)))
    end
    if not core_utils.g_regex_match(pattern, content) then
        error(base_messages.PropertyValueFormatError("******", perproty_name))
    end

    -- only local paths (leading '/') are checked further here
    if content:sub(1, 1) ~= '/' then
        return true
    end
    -- the local file must exist and its real path must stay under /tmp
    if not core_utils.is_file(content) or file_utils.check_real_path_s(content, "/tmp") ~= 0 then
        error(custom_messages[IMPORT_FAILED_MESSAGES[file_type]]())
    end

    if result then
        return true
    end
    error(custom_messages.NoPrivilegeToOperateSpecifiedFile())
end

-- The only certificate collection that currently supports CSR generation: the HTTPS certificate.
local CSR_COLLECTION_URI<const> = "/redfish/v1/Managers/1/SecurityService/HttpsCert"

--- Validate the CertificateCollection parameter of a CSR generation request.
-- @param certificate_collection requested collection URI
-- @return true when supported; raises ActionParameterValueInvalid otherwise
function m.check_certificate_collection_valid(certificate_collection)
    if certificate_collection == CSR_COLLECTION_URI then
        return true
    end
    error(custom_messages.ActionParameterValueInvalid(certificate_collection, 'CertificateCollection'))
end

--- Read a certificate file and return its whole content.
-- @param path file path; it must be an existing file that passes file_utils.check_real_path_s
-- @return string the content, or "" (with an error logged) when the path is rejected, the file
--         cannot be opened, or nothing could be read from it
function m.get_cert_string_by_path(path)
    local path_accepted = core_utils.is_file(path) and file_utils.check_real_path_s(path) == 0
    if not path_accepted then
        log:error('certificate file path is invalid')
        return ""
    end

    local cert_file = file_utils.open_s(path, 'r')
    if cert_file == nil then
        log:error('open certificate file failed')
        return ""
    end
    local text = cert_file:read("*a")
    cert_file:close()

    if text == nil or text == '' then
        log:error('read certificate file failed')
        return ""
    end
    return text
end

-- Redfish Identifier property names and the DN attribute abbreviations they are read from.
-- A list (not a map) so the properties are always emitted in this order into the ordered JSON objects.
local identifier_map = {
    {'CommonName', 'CN'},
    {'OrganizationalUnit', 'OU'},
    {'Organization', 'O'},
    {'City', 'L'},
    {'State', 'S'},
    {'Country', 'C'}
}

-- Return the value of attribute `abbr` in a DN string such as 'CN=foo, OU=bar, O=baz', or nil.
-- The attribute must start the string or follow a comma, so e.g. 'C' never matches inside 'xC=...'.
local function get_dn_attribute(dn, abbr)
    if type(dn) ~= 'string' then
        return nil
    end
    return (',' .. dn):match(',%s*' .. abbr .. '=([^,]*)')
end

--- Build the Issuer/Subject identifier object of a certificate.
-- @param issuer_str     issuer DN string
-- @param subject_str    subject DN string
-- @param schema_version '1.0.0' keeps the historical format (plain strings); other values split the DNs
-- @return ordered JSON object {Issuer = ..., Subject = ...}
function m.get_cert_issuer_and_subject(issuer_str, subject_str, schema_version)
    local identifier = cjson.json_object_new_object()
    -- compatibility with the historical version, where Issuer and Subject are strings
    if schema_version == '1.0.0' then
        identifier.Issuer = issuer_str
        identifier.Subject = subject_str
        return identifier
    end
    local issuer = cjson.json_object_new_object()
    local subject = cjson.json_object_new_object()
    for _, entry in ipairs(identifier_map) do
        local property, abbr = entry[1], entry[2]
        issuer[property] = get_dn_attribute(issuer_str, abbr)
        subject[property] = get_dn_attribute(subject_str, abbr)
    end
    identifier.Issuer = issuer
    identifier.Subject = subject
    return identifier
end

--- Build the Issuer/Subject identifier object (see m.get_cert_issuer_and_subject) and attach the
--- subject alternative names as Subject.AlternativeNames.
-- In the '1.0.0' schema Subject is a plain string, which cannot hold fields, so san is not attached there.
-- @param san value stored in Subject.AlternativeNames
function m.get_cert_subject_with_san(issuer_str, subject_str, schema_version, san)
    local identifier = m.get_cert_issuer_and_subject(issuer_str, subject_str, schema_version)
    local subject = identifier.Subject
    if subject ~= nil and type(subject) ~= 'string' then
        subject.AlternativeNames = san
    end
    return identifier
end

-- Key-usage enum values and their identifiers.
local KEYUSAGE_ENUM_MAP = {
    [0] = "DigitalSignature",
    [1] = "NonRepudiation",
    [2] = "KeyEncipherment",
    [3] = "DataEncipherment",
    [4] = "KeyAgreement",
    [5] = "KeyCertSign",
    [6] = "CRLSigning",
    [7] = "EncipherOnly",
    [8] = "DecipherOnly",
    [9] = "ServerAuthentication",
    [10] = "ClientAuthentication",
    [11] = "CodeSigning",
    [12] = "EmailProtection",
    [13] = "Timestamping",
    [14] = "OCSPSigning"
}

-- Highest enum value the legacy '1.0.0' string format can display: values 9-14 have no entry in the
-- display-name table used by get_key_usage, which used to make the concatenation fail.
local LEGACY_KEYUSAGE_MAX<const> = 8

--- Format the key usages of a certificate.
-- @param keyusage_num_tab list of key-usage enum values; unknown values are skipped
-- @param schema_version   '1.0.0' returns the historical comma-separated string; otherwise a JSON array of names
function m.get_formatted_key_usage(keyusage_num_tab, schema_version)
    local enums = keyusage_num_tab or {}
    -- compatibility with the historical version, where KeyUsage is a string
    if schema_version == '1.0.0' then
        local names = {}
        for _, enum_num in ipairs(enums) do
            local name = KEYUSAGE_ENUM_MAP[enum_num]
            if name and enum_num <= LEGACY_KEYUSAGE_MAX then
                names[#names + 1] = name
            end
        end
        return get_key_usage(names)
    end

    local key_usage = cjson.json_object_new_array()
    for _, enum_num in ipairs(enums) do
        local name = KEYUSAGE_ENUM_MAP[enum_num]
        if name then
            key_usage[#key_usage + 1] = name
        end
    end
    return key_usage
end

--- Read the pending SSL certificate CSR.
-- If the CSR key pair no longer exists, the certificate generated from this CSR has already been imported.
-- @return string|nil CSR content; nil when the key pair is gone or the CSR file cannot be opened
function m.get_certificate_csr()
    if vos.get_file_accessible(CSR_KEY_PAIR) then
        local csr_handle = file_utils.open_s(CERT_CSR, 'r')
        if csr_handle then
            local csr_text = csr_handle:read('*a')
            csr_handle:close()
            return csr_text
        end
    end
    return nil
end

-- PEM header line that starts every certificate block.
local PEM_CERT_HEADER<const> = "-----BEGIN CERTIFICATE-----"

-- Count PEM certificate headers in `content` using a plain (non-pattern) search: '-' is a magic
-- character in Lua patterns, so the header must not be used as a pattern. Non-strings count as 0.
local function count_pem_certificates(content)
    if type(content) ~= 'string' then
        return 0
    end
    local count, init = 0, 1
    while true do
        local _, last = string.find(content, PEM_CERT_HEADER, init, true)
        if not last then
            return count
        end
        count = count + 1
        init = last + 1
    end
end

--- Resolve the usage type and id of a certificate import target from its @odata.id.
-- @param certificate_uri  table with an "@odata.id" field
-- @param certificate_type "PEM" or "PEMchain" (checked for CA certificates only)
-- @param content          certificate text (checked for CA certificates only)
-- @return {UsageType = 0, Id = <id string>} for a CA certificate, {UsageType = 1, Id = 1} for the SSL one;
--         raises PropertyValueNotInList / ResourceMissingAtURI on invalid input
function m.get_certificate_usage_type_and_id(certificate_uri, certificate_type, content)
    local cert_odata_id = certificate_uri["@odata.id"]
    local cert_id = string.match(cert_odata_id, CA_ODATA_ID)
    if cert_id then
        if certificate_type ~= "PEM" and certificate_type ~= "PEMchain" then
            error(base_messages.PropertyValueNotInList(certificate_type, "CA certificate type"))
        end
        -- "PEM" must carry exactly one certificate, "PEMchain" more than one
        local count = count_pem_certificates(content)
        if (certificate_type == "PEM" and count ~= 1) or (certificate_type == "PEMchain" and count <= 1) then
            error(base_messages.PropertyValueNotInList(count, "certificate count"))
        end

        if not m.is_valid_cert_path(1, cert_id) then
            error(base_messages.ResourceMissingAtURI(cert_odata_id))
        end
        -- NOTE(review): Id is the matched string here but the number 1 for SSL; kept as-is for callers
        return {UsageType = 0, Id = cert_id}
    end

    cert_id = string.match(cert_odata_id, SSL_ODATA_ID)
    -- only SSL certificate id 1 can be imported
    if cert_id == "1" then
        if not m.is_valid_ssl_cert_path(1, cert_id) then
            error(base_messages.ResourceMissingAtURI(cert_odata_id))
        end
        return {UsageType = 1, Id = 1}
    end

    -- any other URI is not an importable certificate resource
    error(base_messages.ResourceMissingAtURI(cert_odata_id))
end

-- Export the module table.
return m
