-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.
local utils = require 'mc.utils'
local log = require 'mc.logging'
local utils_file = require 'utils.file'
local cjson = require 'cjson'
local mdb_service = require 'mc.mdb.mdb_service'
local error = require 'mc.error'

local m = {}

-- Account resource tree: object path and interface of the manager accounts
local ACCOUNT_MDB_PATH <const> = '/bmc/kepler/AccountService/Accounts'
local ACCOUNT_MDB_INTF <const> = 'bmc.kepler.AccountService.ManagerAccount'

-- Account client-certificate resource tree; the path is formatted with an account id (%d).
-- NOTE: "INFT" is a historical misspelling of "INTF", kept because the name is referenced elsewhere.
local ACCOUNT_CERT_MDB_PATH <const> = '/bmc/kepler/AccountService/MultiFactorAuth/ClientCertificate/Certificates/%d'
local ACCOUNT_CERT_MDB_INFT <const> = 'bmc.kepler.CertificateService.Certificate.Account'
local CERT_MDB_INTF <const> = 'bmc.kepler.CertificateService.Certificate'

---Return the usable local administrator accounts (local, enabled, RoleId 4 = Administrator).
---Callers without the UserMgmt privilege only ever get themselves back, and only when
---their own role is "Administrator"; any other caller without UserMgmt gets an empty list.
---@param user_name string name of the calling user
---@param role_id string role name of the calling user (compared against "Administrator")
---@param privileges table privilege names held by the calling user
---@return table list of user names
function m.get_available_users(user_name, role_id, privileges)
    local can_manage_users = false
    for _, item in pairs(privileges) do
        if item == "UserMgmt" then
            can_manage_users = true
            break
        end
    end

    if not can_manage_users then
        -- Without UserMgmt an Administrator may only pick itself; anyone else gets nothing
        if role_id == "Administrator" then
            return { user_name }
        end
        return {}
    end

    -- With UserMgmt: every enabled local account whose RoleId is 4 (Administrator)
    local result = {}
    local accounts = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    for _, obj in pairs(accounts) do
        if obj.AccountType == "Local" and obj.Enabled == true and obj.RoleId == 4 then
            result[#result + 1] = obj.UserName
        end
    end
    return result
end

---Check whether a user's or role's privilege list contains a given privilege.
---@param privileges table sequence of privilege names held by the user or role
---@param privilege_name string privilege to look for
---@return boolean true when privilege_name appears in privileges
function m.is_has_privilege(privileges, privilege_name)
    for _, granted in ipairs(privileges) do
        if granted == privilege_name then
            return true
        end
    end
    return false
end

-- Role name -> role id
local role_map = {
    NoAccess = 0,
    CommonUser = 2,
    Operator = 3,
    Administrator = 4,
    CustomRole1 = 5,
    CustomRole2 = 6,
    CustomRole3 = 7,
    CustomRole4 = 8
}

-- Reverse index (role id -> role name), derived once from role_map; the ids are unique.
local role_name_by_id = {}
for name, id in pairs(role_map) do
    role_name_by_id[id] = name
end

---Look up a role name by its id.
---@param id number|string role id; numeric strings are converted with tonumber
---@return string|nil role name, or nil when the id is unknown or not numeric
function m.get_role_name(id)
    return role_name_by_id[tonumber(id)]
end

---Look up a role id by its name.
---@param role_name string role name, e.g. 'Administrator'
---@return number|nil role id, or nil for an unknown name
function m.get_role_id(role_name)
    local id = role_map[role_name]
    return id
end

-- Role resource tree (the role id is appended to the path)
local ROLE_MDB_PATH <const> = "/bmc/kepler/AccountService/Roles/"
local ROLE_MDB_INTF <const> = "bmc.kepler.AccountService.Role"

-- Output field -> backend privilege name, listed in the order the fields are filled in.
local ROLE_PRIVILEGE_FIELDS <const> = {
    { 'DiagnosisEnabled', 'DiagnoseMgmt' },
    { 'QueryEnabled', 'ReadOnly' },
    { 'ConfigureSelfEnabled', 'ConfigureSelf' },
    { 'UserConfigEnabled', 'UserMgmt' },
    { 'VMMEnabled', 'VMMMgmt' },
    { 'BasicConfigEnabled', 'BasicSetting' },
    { 'RemoteControlEnabled', 'KVMMgmt' },
    { 'SecurityConfigEnabled', 'SecurityMgmt' },
    { 'PowerControlEnabled', 'PowerMgmt' }
}

---Build the privilege summary of one role from its Role object.
---Raises (nil concatenation) if role_name is not a known role.
---@param role_name string role name, e.g. 'Administrator' or 'CustomRole1'
---@return table { ID, Name, and one boolean per ROLE_PRIVILEGE_FIELDS entry }
local function get_role_privilege(role_name)
    local id = m.get_role_id(role_name)
    local role = mdb.get_object(bus, ROLE_MDB_PATH .. id, ROLE_MDB_INTF)
    local granted = role.RolePrivilege

    local summary = {}
    summary.ID = id
    summary.Name = role_name
    for _, field in ipairs(ROLE_PRIVILEGE_FIELDS) do
        summary[field[1]] = m.is_has_privilege(granted, field[2])
    end
    return summary
end

---Return the privilege summaries visible to the current user.
---Users holding UserMgmt get every built-in and custom role; everyone else only
---gets the summary of their own role.
---@param role_name string current user's role name
---@param privileges table current user's privilege names
---@return table { Count = number of entries, Privileges = list of role summaries }
function m.get_all_roles_privilege(role_name, privileges)
    local entries = {}

    if m.is_has_privilege(privileges, "UserMgmt") then
        local all_roles = { 'Administrator', 'Operator', 'CommonUser', 'NoAccess', 'CustomRole1',
            'CustomRole2', 'CustomRole3', 'CustomRole4' }
        for _, name in ipairs(all_roles) do
            entries[#entries + 1] = get_role_privilege(name)
        end
    else
        entries[1] = get_role_privilege(role_name)
    end

    return { Count = #entries, Privileges = entries }
end

-- Privilege name -> privilege id
local privilege_map = {
    UserMgmt = 0,
    BasicSetting = 1,
    KVMMgmt = 2,
    VMMMgmt = 3,
    SecurityMgmt = 4,
    PowerMgmt = 5,
    DiagnoseMgmt = 6,
    ReadOnly = 7,
    ConfigureSelf = 8
}

---Look up the id of a privilege name.
---@param privilege_name string privilege name, e.g. 'UserMgmt'
---@return number|nil privilege id, or nil for an unknown name
function m.get_privilege_id(privilege_name)
    local id = privilege_map[privilege_name]
    return id
end

---Return the account names selectable as the SNMPv3 trap user.
---@param user_name string name of the calling user
---@param auth_type number caller's authentication type; only 0 lets a caller without UserMgmt see itself
---@param privileges table privilege names held by the caller
---@return table list of user names
function m.get_snmp_trap_available_users(user_name, auth_type, privileges)
    -- Without UserMgmt the caller may only select itself, and only when auth_type is 0
    -- and it holds at least one privilege.
    if not m.is_has_privilege(privileges, 'UserMgmt') then
        if auth_type ~= 0 or #privileges == 0 then
            return {}
        end
        -- Build a fresh list holding the caller's name (user_name is a string and must not be indexed).
        return { user_name }
    end

    -- With UserMgmt: enabled local accounts whose RoleId is not 0 (NoAccess) and whose
    -- PasswordExpiration is not 0 (NOTE(review): meaning of 0 is not visible here — confirm).
    local users_name = {}
    local account_objs = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    for _, account in pairs(account_objs) do
        if account.AccountType == "Local" and account.Enabled == true and account.RoleId ~= 0 and
            account.PasswordExpiration ~= 0 then
            users_name[#users_name + 1] = account.UserName
        end
    end

    return users_name
end

-- AccountService resource tree
local ACCOUNT_SERVICE_MDB_PATH <const> = '/bmc/kepler/AccountService'
local ACCOUNT_SERVICE_MDB_INTF <const> = 'bmc.kepler.AccountService'

-- When both ages are non-zero, max - min must be strictly greater than this many days.
local PASSWORD_AGE_MIN_GAP <const> = 10

-- Write MinPasswordValidDays first, then MaxPasswordValidDays.
local function write_min_then_max(service, max_age, min_age)
    service.MinPasswordValidDays = min_age
    service.MaxPasswordValidDays = max_age
end

-- Write MaxPasswordValidDays first, then MinPasswordValidDays.
local function write_max_then_min(service, max_age, min_age)
    service.MaxPasswordValidDays = max_age
    service.MinPasswordValidDays = min_age
end

---Set the maximum and/or minimum password age (days) on the AccountService object.
---When both values are given they are validated against each other, and the two
---properties are written in an order chosen so that the intermediate state (one new
---value next to the other current value) also respects the gap rule — presumably
---because each property write is validated on its own (TODO confirm). A value of 0
---is exempt from the gap rule.
---@param max_age number|nil new MaxPasswordValidDays; nil leaves it unchanged
---@param min_age number|nil new MinPasswordValidDays; nil leaves it unchanged
---@return boolean always true; raises MinPwdAgeAndPwdValidityRestrictEachOther on conflict
function m.set_max_and_min_password_age_days(max_age, min_age)
    local service = mdb.get_object(bus, ACCOUNT_SERVICE_MDB_PATH, ACCOUNT_SERVICE_MDB_INTF)

    -- Only one (or neither) value supplied: write what was given and let the backend validate it.
    if max_age == nil or min_age == nil then
        if max_age ~= nil then
            service.MaxPasswordValidDays = max_age
        end
        if min_age ~= nil then
            service.MinPasswordValidDays = min_age
        end
        return true
    end

    local current_max = service.MaxPasswordValidDays
    local current_min = service.MinPasswordValidDays

    -- Both supplied but violating the gap rule with each other: reject without writing either.
    if max_age ~= 0 and min_age ~= 0 and max_age - min_age <= PASSWORD_AGE_MIN_GAP then
        local custom_msg = require 'messages.custom'
        error(custom_msg.MinPwdAgeAndPwdValidityRestrictEachOther())
    end

    if current_max ~= 0 and current_max - min_age > PASSWORD_AGE_MIN_GAP then
        -- The new min is compatible with the current max
        write_min_then_max(service, max_age, min_age)
    elseif current_min ~= 0 and max_age - current_min > PASSWORD_AGE_MIN_GAP then
        -- The new max is compatible with the current min
        write_max_then_min(service, max_age, min_age)
    elseif current_max == 0 or min_age == 0 then
        -- Current max or new min is 0, so writing min first cannot trip the rule
        write_min_then_max(service, max_age, min_age)
    else
        -- Current min or new max is 0 (or no safe order exists): write max first
        write_max_then_min(service, max_age, min_age)
    end
    return true
end

---Collect the two-factor (client certificate) information of every local account,
---sorted by ascending account id.
---@return table list of per-user tables with UserId, RoleID (role name), UserName,
---        RevokedState, SerialNumber, RevokedDate, RootCertUploadedState, Fingerprint
function m.get_all_local_users_mutual_auth_info()
    local user_infos = {}
    local account_objs = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    for _, account in pairs(account_objs) do
        if account.AccountType == "Local" then
            local account_id = account.Id
            local user_info = {}
            user_info.UserId = account_id
            -- RoleID carries the role name converted from the numeric RoleId
            user_info.RoleID = m.get_role_name(account.RoleId)
            user_info.UserName = account.UserName

            local path = string.format(ACCOUNT_CERT_MDB_PATH, account_id)
            -- The certificate object may be absent for this account; any lookup failure is
            -- reported as "no certificate" with empty values.
            local ok, account_cert_obj = pcall(mdb.get_object, bus, path, ACCOUNT_CERT_MDB_INFT)
            if ok then
                user_info.RevokedState = account_cert_obj.RevokedState
                user_info.SerialNumber = account_cert_obj.SerialNumber
                user_info.RevokedDate = account_cert_obj.RevokedDate
                user_info.RootCertUploadedState = account_cert_obj.RootCertUploadedState
                user_info.Fingerprint = mdb.get_object(bus, path, CERT_MDB_INTF).Fingerprint
            else
                user_info.RevokedState = ""
                user_info.SerialNumber = ""
                user_info.RevokedDate = cjson.null
                user_info.RootCertUploadedState = false
                -- Same key as the success branch (mirrors the mdb property name, like the
                -- other fields) so consumers always find the field under one name.
                user_info.Fingerprint = ""
            end
            user_infos[#user_infos + 1] = user_info
        end
    end
    table.sort(user_infos, function(a, b) return a.UserId < b.UserId end)
    return user_infos
end

---Count the accounts whose AccountType is "Local".
---@return integer number of local accounts
function m.get_all_local_users_number()
    local count = 0
    local accounts = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    for _, obj in pairs(accounts) do
        if obj.AccountType == "Local" then
            count = count + 1
        end
    end
    return count
end

---Check whether id refers to an existing account.
---A valid id is an integer in [2, 17] whose account object path exists.
---@param id number|string account id; numeric strings are accepted
---@return boolean true when the account exists
function m.is_valid_account_id(id)
    local number_id = tonumber(id)
    -- tonumber returns nil when id is a string made of non-numeric characters
    if not number_id then
        log:error('the type of value(%s) for id is invalid', tostring(id))
        return false
    end
    -- Non-integral numbers (e.g. "2.5", inf) are never account ids; reject them here so
    -- they are never handed to an integer format conversion in the log calls below.
    if math.type(number_id) ~= 'integer' then
        log:error('the id(%s) of account is invalid', tostring(id))
        return false
    end
    -- %d (not %u) so that negative ids are logged as given
    if number_id < 2 or number_id > 17 then
        log:error('the id(%d) of account is invalid', number_id)
        return false
    end
    local path = '/bmc/kepler/AccountService/Accounts/' .. id
    local ok, rsp = pcall(mdb_service.is_valid_path, bus, path)
    if not ok or not rsp.Result then
        log:error('the id(%d) of account is invalid', number_id)
        return false
    end
    return true
end

---Map backend login-rule names to their Web REST resource URIs, in ascending
---rule-name order. Unknown rule names are skipped.
---@param rules table list of rule names such as 'Rule1'; it is not modified
---@return table list of URIs
function m.get_login_rule(rules)
    -- Backend login rule name -> URI returned by the Web interface
    local login_rule_map = {
        Rule1 = '/UI/Rest/AccessMgnt/LoginRule/1',
        Rule2 = '/UI/Rest/AccessMgnt/LoginRule/2',
        Rule3 = '/UI/Rest/AccessMgnt/LoginRule/3'
    }

    -- Sort a copy: table.sort works in place and would reorder the caller's list.
    local sorted_rules = {}
    for i, rule in ipairs(rules) do
        sorted_rules[i] = rule
    end
    table.sort(sorted_rules)

    local login_rules = {}
    for _, rule in ipairs(sorted_rules) do
        -- Assigning nil for an unknown rule is a no-op, so unknown names are dropped
        login_rules[#login_rules + 1] = login_rule_map[rule]
    end

    return login_rules
end

-- SNMP authentication protocol name -> protocol id
local snmp_authentication_protocols_map = {
    MD5 = 1,
    SHA = 2,
    SHA1 = 3,
    SHA256 = 4,
    SHA384 = 5,
    SHA512 = 6
}
-- SNMP encryption protocol name -> protocol id
local snmp_encryption_protocols_map = {
    DES = 1,
    AES = 2,
    AES256 = 3
}

-- algorithm_type == 1 selects the authentication map; any other value selects the encryption map.
local function snmp_protocols_map_for(algorithm_type)
    if algorithm_type == 1 then
        return snmp_authentication_protocols_map
    end
    return snmp_encryption_protocols_map
end

---Look up the id of an SNMP protocol name.
---@param protocol string protocol name, e.g. 'SHA256' or 'AES'
---@param algorithm_type number 1 for authentication, anything else for encryption
---@return number|nil protocol id, or nil when the name is not in the selected map
function m.get_snmp_protocol_id(protocol, algorithm_type)
    -- Look up only in the selected map: with `cond and auth[p] or enc[p]`, a miss in the
    -- authentication map would fall through to the encryption map (e.g. 'AES', 1 -> 2).
    return snmp_protocols_map_for(algorithm_type)[protocol]
end

---Look up the name of an SNMP protocol id.
---@param protocol_id number protocol id
---@param algorithm_type number 1 for authentication, anything else for encryption
---@return string|nil protocol name, or nil when no protocol in the selected map has that id
function m.get_snmp_protocol(protocol_id, algorithm_type)
    for name, id in pairs(snmp_protocols_map_for(algorithm_type)) do
        if id == protocol_id then
            return name
        end
    end
    return nil
end

---Return why the account user_id may not be deleted, if there is such a reason.
---@param account_service table AccountService object (reads EmergencyLoginAccountId,
---       SNMPv3TrapAccountLimitPolicy and SNMPv3TrapAccountId)
---@param user_id number account id being checked
---@return string|userdata 'EmergencyUser', 'TrapV3User', or cjson.null when there is no reason
function m.get_del_disable_reason(account_service, user_id)
    local emergency_id = account_service.EmergencyLoginAccountId
    local trap_policy = account_service.SNMPv3TrapAccountLimitPolicy
    local trap_account_id = account_service.SNMPv3TrapAccountId

    -- The configured emergency login account
    if emergency_id and emergency_id == user_id then
        return 'EmergencyUser'
    end

    -- With a non-zero SNMPv3 trap account limit policy, the trap account cannot be deleted
    if trap_policy ~= 0 and trap_account_id == user_id then
        return 'TrapV3User'
    end

    return cjson.null
end

-- Extract the trailing numeric account id from an account object path.
local function account_id_of(path)
    return tonumber(string.match(path, '.*/(%d+)'))
end

---Return the account object paths the caller may see, sorted by account id.
---@param accounts_paths table|nil candidate account object paths
---@param context table request context (reads Privilege and AccountId)
---@return table sorted list of paths
function m.get_paths(accounts_paths, context)
    local visible

    if m.is_has_privilege(context.Privilege, 'UserMgmt') then
        -- UserMgmt holders see every local account among the candidates
        visible = {}
        for _, path in ipairs(accounts_paths or {}) do
            local account = mdb.get_object(bus, path, ACCOUNT_MDB_INTF)
            if account.AccountType == 'Local' then
                visible[#visible + 1] = path
            end
        end
    else
        -- Everyone else sees only their own account
        visible = { ACCOUNT_MDB_PATH .. '/' .. context.AccountId }
    end

    table.sort(visible, function(left, right)
        return account_id_of(left) < account_id_of(right)
    end)
    return visible
end

-- Login interface name -> numeric flag value (powers of two; Invalid is 0)
local login_interface_map = {
    Invalid = 0,
    Web = 1,
    SNMP = 2,
    IPMI = 4,
    SSH = 8,
    SFTP = 16,
    Local = 64,
    Redfish = 128
}

---Translate login interface names into their numeric flag values.
---@param interface table|nil list of interface names such as 'Web' or 'SSH'
---@return table list of numbers; unknown names are skipped, nil input gives an empty list
function m.get_interface_number(interface)
    local numbers = {}
    if interface then
        for _, name in pairs(interface) do
            numbers[#numbers + 1] = login_interface_map[name]
        end
    end
    return numbers
end

---Changing the SNMP authentication protocol also requires resetting the SNMP
---privacy (encryption) password and the login password; check both were supplied.
---@param password string|nil new login password
---@param snmp_priv_passwd string|nil new SNMP privacy password
---@return boolean true when both are present; otherwise raises ModifyAuthProtocolLackProp
function m.check_set_snmp_auth_protocol_parameter(password, snmp_priv_passwd)
    if not password or not snmp_priv_passwd then
        -- Require the message catalogue explicitly (as set_max_and_min_password_age_days
        -- does) instead of relying on a `custom_messages` global this module never defines.
        local custom_msg = require 'messages.custom'
        local err = custom_msg.ModifyAuthProtocolLackProp()
        err.RelatedProperties = {'SnmpV3AuthProtocol', 'SnmpV3PrivPasswd', 'Password'}
        error(err)
    end
    return true
end

---Check whether a privilege list contains UserMgmt.
---@param privileges table sequence of privilege names
---@return boolean true when 'UserMgmt' is present
function m.is_have_user_mgmt(privileges)
    return m.is_has_privilege(privileges, 'UserMgmt')
end

---Return the SNMP community strings, visible only to callers holding UserMgmt.
---@param ctx table request context (reads Privilege)
---@param rw_community string read-write community
---@param ro_community string read-only community
---@return table { rw_community, ro_community }, or a table with both entries nil
function m.get_snmp_community(ctx, rw_community, ro_community)
    -- Plaintext SNMP community strings are only exposed to users with UserMgmt
    if not m.is_have_user_mgmt(ctx.Privilege) then
        return {nil, nil}
    end
    return {rw_community, ro_community}
end

---Convert a list of enabled TLS version names into four booleans.
---@param tls_array table version names such as 'TLS1.2'
---@return table { TLS1.0, TLS1.1, TLS1.2, TLS1.3 } enabled flags, in that order
function m.split_tls_enabled(tls_array)
    local seen = {}
    for _, name in pairs(tls_array) do
        seen[name] = true
    end

    return {
        seen['TLS1.0'] == true,
        seen['TLS1.1'] == true,
        seen['TLS1.2'] == true,
        seen['TLS1.3'] == true
    }
end

---List the enabled TLS versions.
---@param states table flags keyed by 'TLS_1_0_Enabled' ... 'TLS_1_3_Enabled'
---@return table enabled version names ('TLS1.0' ... 'TLS1.3') in ascending order
function m.get_tls_version(states)
    local flags = {
        { 'TLS_1_0_Enabled', 'TLS1.0' },
        { 'TLS_1_1_Enabled', 'TLS1.1' },
        { 'TLS_1_2_Enabled', 'TLS1.2' },
        { 'TLS_1_3_Enabled', 'TLS1.3' }
    }

    local enabled = {}
    -- Walking the flags in ascending version order yields an already sorted list
    for _, flag in ipairs(flags) do
        if states[flag[1]] then
            enabled[#enabled + 1] = flag[2]
        end
    end
    return enabled
end

---Return the sorted names of the supported TLS versions.
---@param tls_capabilities table entries shaped like { TlsName = string, IsSupported = boolean }
---@return table sorted list of TlsName values whose IsSupported is truthy
function m.split_tls_capabilities(tls_capabilities)
    local supported = {}
    for _, capability in pairs(tls_capabilities) do
        if capability.IsSupported then
            supported[#supported + 1] = capability.TlsName
        end
    end
    table.sort(supported)
    return supported
end

-- Export the module table
return m
