-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local mdb_service = require 'mc.mdb.mdb_service'
local cjson = require 'cjson'
local os = require 'os'

-- MDB object paths and interfaces of the account resources
local ACCOUNT_MDB_PATH <const> = '/bmc/kepler/AccountService/Accounts'
local LOCAL_ACCOUNT_POLICY_MDB_PATH <const> = '/bmc/kepler/AccountService/AccountPolicies/Local'
local OEM_ACCOUNT_POLICY_MDB_PATH <const> = '/bmc/kepler/AccountService/AccountPolicies/OemAccount'
local ACCOUNT_MDB_INTF <const> = 'bmc.kepler.AccountService.ManagerAccount'
local ACCOUNT_POLICY_MDB_INTF <const> = 'bmc.kepler.AccountService.AccountPolicy'

-- MDB object path and interface of the manager time objects
local TIME_MDB_PATH <const> = '/bmc/kepler/Managers'
local TIME_MDB_INTF <const> = 'bmc.kepler.Managers.Time'

-- Module table: the public helpers are attached to it and it is returned at the end of the file
local m = {}

---Collect the user names of all local accounts on this BMC.
-- Only accounts whose AccountType is "Local" are included; the order follows the
-- iteration order of the table returned by mdb.get_sub_objects.
-- @treturn table array of user names
function m.get_users_name()
    local account_objs = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    local names = {}
    for _, obj in pairs(account_objs) do
        local is_local = obj.AccountType == "Local"
        if is_local then
            table.insert(names, obj.UserName)
        end
    end
    return names
end

---Tell whether a privilege list contains 'UserMgmt'.
-- A missing list (nil or JSON null) grants nothing, so the caller is treated as
-- lacking the privilege instead of raising inside ipairs.
-- @param privileges array of privilege names, or nil
-- @treturn boolean true when 'UserMgmt' is present
function m.is_have_user_mgmt(privileges)
    if privileges == nil or privileges == cjson.null then
        return false
    end

    for _, privilege in ipairs(privileges) do
        if privilege == 'UserMgmt' then
            return true
        end
    end

    return false
end

---Build a map from AccountType to the Visible flag of its account policy.
-- 'Local' is read from the local-account policy object and 'OEM' from the
-- OEM-account policy object, both through the AccountPolicy interface.
-- @treturn table map with the keys 'Local' and 'OEM'
function m.get_account_visible_map()
    local local_policy = mdb.get_object(bus, LOCAL_ACCOUNT_POLICY_MDB_PATH, ACCOUNT_POLICY_MDB_INTF)
    local oem_policy = mdb.get_object(bus, OEM_ACCOUNT_POLICY_MDB_PATH, ACCOUNT_POLICY_MDB_INTF)
    return {
        Local = local_policy.Visible,
        OEM = oem_policy.Visible,
    }
end

---Count the accounts visible to the caller.
-- A caller holding 'UserMgmt' sees every account whose AccountType is marked
-- visible in the account-policy map ('Local' / 'OEM'); accounts of any other type
-- are not counted. Any other caller only sees their own account, so the count is 1.
-- @param Context request context; Context.Privilege is the caller's privilege list
-- @treturn number number of visible accounts
function m.get_count(Context)
    if not m.is_have_user_mgmt(Context.Privilege) then
        return 1
    end

    -- The policy objects are only read when accounts are actually enumerated
    local account_visible_map = m.get_account_visible_map()
    local account_objs = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    local count = 0
    for _, account in pairs(account_objs) do
        if account_visible_map[account.AccountType] then
            count = count + 1
        end
    end

    return count
end

-- Extract the trailing numeric segment (the account Id) of an account path, or nil.
local function account_id_from_path(path)
    return tonumber(string.match(path, '.*/(%d+)'))
end

-- Strict ordering for account paths: ascending numeric Id, paths without a numeric
-- Id after all numbered ones, ties broken by the path string. This keeps
-- table.sort valid even when a path has no numeric tail.
local function account_path_less(path1, path2)
    local id1, id2 = account_id_from_path(path1), account_id_from_path(path2)
    if id1 ~= nil and id2 ~= nil then
        if id1 ~= id2 then
            return id1 < id2
        end
    elseif id1 ~= nil or id2 ~= nil then
        return id1 ~= nil
    end
    return path1 < path2
end

---Return the account resource paths visible to the caller, sorted by account Id.
-- A caller holding 'UserMgmt' gets every path in accounts_paths whose object's
-- AccountType is marked visible in the account-policy map. Any other caller only
-- gets the path of their own account (context.AccountId).
-- @param accounts_paths array of account object paths (nil is treated as empty)
-- @param context request context; Privilege and AccountId are read
-- @treturn table array of paths in ascending order of their trailing numeric Id
function m.get_paths(accounts_paths, context)
    if not m.is_have_user_mgmt(context.Privilege) then
        return { ACCOUNT_MDB_PATH .. '/' .. context.AccountId }
    end

    -- The policy objects are only read when accounts are actually enumerated
    local account_visible_map = m.get_account_visible_map()
    local paths = {}
    for _, path in ipairs(accounts_paths or {}) do
        local obj = mdb.get_object(bus, path, ACCOUNT_MDB_INTF)
        if account_visible_map[obj.AccountType] then
            paths[#paths + 1] = path
        end
    end

    table.sort(paths, account_path_less)
    return paths
end

-- Mapping between account role names and their numeric role Ids.
-- The named roles use 0 and 2-4; CustomRole1..CustomRole16 map to 5..20.
local role_type_map = {
    Noaccess = 0,
    Commonuser = 2,
    Operator = 3,
    Administrator = 4
}
for index = 1, 16 do
    role_type_map['CustomRole' .. index] = index + 4
end

---Look up the numeric role Id for a role name.
-- @param role role name, e.g. 'Administrator'
-- @return the role Id, or nil for an unknown role
function m.get_role_id(role)
    local role_id = role_type_map[role]
    return role_id
end

---Look up the role name for a numeric role Id (reverse of get_role_id).
-- @param role_id numeric role Id
-- @return the role name, or nil when no role has that Id
function m.get_role(role_id)
    local found = nil
    for name, id in pairs(role_type_map) do
        if id == role_id then
            found = name
            break
        end
    end
    return found
end

-- Mapping between login policy names and their numeric Ids.
local login_policy_map = {
    ['PromptPasswordReset'] = 1,
    ['ForcePasswordReset'] = 2,
}

---Return the numeric Id of a login policy name, or nil if unknown.
-- @param policy login policy name
function m.get_login_policy_id(policy)
    local policy_id = login_policy_map[policy]
    return policy_id
end

---Return the login policy name for a numeric Id, or nil if unknown.
-- @param policy_id numeric login policy Id
function m.get_login_policy(policy_id)
    for name, id in pairs(login_policy_map) do
        if id == policy_id then
            return name
        end
    end
    return nil
end

-- Mapping between SNMP authentication protocol names and their numeric Ids.
local snmp_authentication_protocols_map = {
    ['MD5'] = 1,
    ['SHA'] = 2,
    ['SHA1'] = 3,
    ['SHA256'] = 4,
    ['SHA384'] = 5,
    ['SHA512'] = 6,
}

---Return the numeric Id of an SNMP authentication protocol name, or nil if unknown.
-- @param protocol protocol name such as 'SHA256'
function m.get_snmp_auth_protocol_id(protocol)
    return (snmp_authentication_protocols_map[protocol])
end

---Return the SNMP authentication protocol name for a numeric Id, or nil if unknown.
-- @param protocol_id numeric protocol Id
function m.get_snmp_auth_protocol(protocol_id)
    local match = nil
    for name, id in pairs(snmp_authentication_protocols_map) do
        if id == protocol_id then
            match = name
            break
        end
    end
    return match
end

-- Mapping between SNMP encryption (privacy) protocol names and their numeric Ids.
local snmp_encryption_protocols_map = { DES = 1, AES = 2, AES256 = 3 }

---Return the numeric Id of an SNMP encryption protocol name, or nil if unknown.
-- @param protocol protocol name such as 'AES'
function m.get_snmp_crypto_protocol_id(protocol)
    local protocol_id = snmp_encryption_protocols_map[protocol]
    return protocol_id
end

---Return the SNMP encryption protocol name for a numeric Id, or nil if unknown.
-- @param protocol_id numeric protocol Id
function m.get_snmp_crypto_protocol(protocol_id)
    for name, id in pairs(snmp_encryption_protocols_map) do
        if id == protocol_id then
            return name
        end
    end
    return nil
end

-- URI template of a login rule: %s is the slot_id, %d the rule number.
local login_rule_format = "/redfish/v1/Managers/%s#/Oem/{{OemIdentifier}}/LoginRule/%d"

-- Mapping from login rule name to the number used in its URI.
local login_rule_map = {
    Rule1 = 1,
    Rule2 = 2,
    Rule3 = 3,
}

---Convert login rule names into Redfish reference objects.
-- @param slot_id value substituted for %s in the URI template
-- @param login_rule_ids array of rule names ('Rule1'..'Rule3')
-- @treturn table array of { ["@odata.id"] = uri }, in input order
function m.get_login_rule(slot_id, login_rule_ids)
    local login_rules = {}
    for index, rule_name in ipairs(login_rule_ids) do
        login_rules[index] = {
            ["@odata.id"] = string.format(login_rule_format, slot_id, login_rule_map[rule_name])
        }
    end
    return login_rules
end

---Convert the stored password validity period for the Redfish interface.
-- @param pwd_valid_days stored number of valid days; 0xFFFFFFFF means the password
--   never expires
-- @return pwd_valid_days unchanged, or nil for the never-expires value
function m.get_pwd_valid_days(pwd_valid_days)
    local NEVER_EXPIRES <const> = 0xFFFFFFFF
    if pwd_valid_days ~= NEVER_EXPIRES then
        return pwd_valid_days
    end
    return nil
end

-- MDB object paths and interfaces of the session, account-service and SNMP
-- trap subscription resources.
-- NOTE(review): SESSION_* and SNMP_* are not referenced anywhere in this file.
local SESSION_MDB_PATH <const> = '/bmc/kepler/SessionService/Sessions'
local ACCOUNT_SERVICE_MDB_PATH <const> = '/bmc/kepler/AccountService'
local SNMP_MDB_PATH <const> = '/bmc/kepler/EventService/Subscriptions/Snmp'
local SESSION_MDB_INTF <const> = 'bmc.kepler.SessionService.Session'
local ACCOUNT_SERVICE_MDB_INTF <const> = 'bmc.kepler.AccountService'
local SNMP_MDB_INTF <const> = 'bmc.kepler.EventService.Subscriptions.Snmp'

-- Collect the Ids of all local accounts that are enabled and hold the Administrator role.
local function get_admin_account_ids()
    local accounts = mdb.get_sub_objects(bus, ACCOUNT_MDB_PATH, ACCOUNT_MDB_INTF)
    local ids = {}
    for _, acct in pairs(accounts) do
        local is_enabled_local_admin = acct.AccountType == "Local"
            and acct.RoleId == role_type_map.Administrator
            and acct.Enabled == true
        if is_enabled_local_admin then
            ids[#ids + 1] = acct.Id
        end
    end
    return ids
end

---Return the reason why an account may not be deleted.
-- Checked in order: the emergency login account, the SNMPv3 trap account (only
-- while SNMPv3TrapAccountLimitPolicy is non-zero), and the only enabled local
-- administrator.
-- @param account_service_obj AccountService object (EmergencyLoginAccountId,
--   SNMPv3TrapAccountLimitPolicy and SNMPv3TrapAccountId are read)
-- @param user_id Id of the account to be deleted
-- @return 'EmergencyUser', 'TrapV3User' or 'UniqueAdminUser' when deletion is
--   prohibited, otherwise cjson.null
function m.get_del_disable_reason(account_service_obj, user_id)
    local emergency_account_id = account_service_obj.EmergencyLoginAccountId
    local snmp_v3_trap_policy = account_service_obj.SNMPv3TrapAccountLimitPolicy
    local snmp_v3_trap_account_id = account_service_obj.SNMPv3TrapAccountId
    -- The emergency login account may not be deleted
    if emergency_account_id == user_id then
        return 'EmergencyUser'
    end
    -- While the SNMPv3 trap limit policy is non-zero, the trap account may not be deleted
    if snmp_v3_trap_policy ~= 0 and snmp_v3_trap_account_id == user_id then
        return 'TrapV3User'
    end
    -- The only enabled local administrator may not be deleted. The check must be
    -- about this very account; otherwise every account would be reported as
    -- undeletable whenever exactly one administrator exists.
    local admin_ids = get_admin_account_ids()
    if #admin_ids == 1 and admin_ids[1] == user_id then
        return 'UniqueAdminUser'
    end
    return cjson.null
end

---Write a new maximum/minimum password age pair to the AccountService object.
-- When both new values are non-zero, max_age - min_age must be greater than 10;
-- a value of 0 is exempt from that check. The two properties are written one at
-- a time, and the branches below try to pick the write order so that the first
-- write is compatible with the value still stored for the other property.
-- NOTE(review): the ordering presumably exists because the MDB property setters
-- validate each write against the currently stored counterpart; confirm.
-- @param max_age new MaxPasswordValidDays value (days)
-- @param min_age new MinPasswordValidDays value (days)
-- @param account_service_obj AccountService object whose properties are written
-- @return true
-- @raise MinPwdAgeAndPwdValidityRestrictEachOther (messages.custom) when the new pair conflicts
local function set_max_and_min_password_age_days_proc(max_age, min_age, account_service_obj)
    local mdb_max_age = account_service_obj.MaxPasswordValidDays
    local mdb_min_age = account_service_obj.MinPasswordValidDays

    -- Both values are set but violate the mutual constraint: raise and set neither
    if max_age ~= 0 and min_age ~= 0 and max_age - min_age <= 10 then
        local custom_msg = require 'messages.custom'
        error(custom_msg.MinPwdAgeAndPwdValidityRestrictEachOther())
    end

    -- The new min is compatible with the stored max: write min first
    if mdb_max_age ~= 0 and mdb_max_age - min_age > 10 then
        account_service_obj.MinPasswordValidDays = min_age
        account_service_obj.MaxPasswordValidDays = max_age
        return true
    end

    -- The new max is compatible with the stored min: write max first
    if mdb_min_age ~= 0 and max_age - mdb_min_age > 10 then
        account_service_obj.MaxPasswordValidDays = max_age
        account_service_obj.MinPasswordValidDays = min_age
        return true
    end

    -- The stored max or the new min is 0 (exempt from the check): write min first
    if mdb_max_age == 0 or min_age == 0 then
        account_service_obj.MinPasswordValidDays = min_age
        account_service_obj.MaxPasswordValidDays = max_age
        return true
    end

    -- The stored min or the new max is 0 (exempt from the check): write max first
    if mdb_min_age == 0 or max_age == 0 then
        account_service_obj.MaxPasswordValidDays = max_age
        account_service_obj.MinPasswordValidDays = min_age
        return true
    end

    -- No order is known to be compatible with the stored values: write max, then min
    account_service_obj.MaxPasswordValidDays = max_age
    account_service_obj.MinPasswordValidDays = min_age
    return true
end

---Apply the maximum / minimum password age from a request to AccountService.
-- The maximum age comes from max_age_req when present, otherwise from
-- Oem/{{OemIdentifier}}/PasswordValidityDays. The minimum age comes from
-- Oem/{{OemIdentifier}}/MinimumPasswordAgeDays. A JSON null for either is stored as 0.
-- When only one of the two is supplied, it is written directly without the
-- cross-check. When both are supplied, set_max_and_min_password_age_days_proc
-- validates the pair and writes it.
-- @param reqbody_oem the request's Oem object, or nil when absent
-- @param max_age_req top-level maximum password age, or nil when absent
-- @return true (failures are raised as errors)
function m.set_max_and_min_password_age_days(reqbody_oem, max_age_req)
    local account_service_obj = mdb.get_object(bus, ACCOUNT_SERVICE_MDB_PATH, ACCOUNT_SERVICE_MDB_INTF)

    -- Treat a non-table Oem, or an Oem without the vendor sub-object, as "not supplied"
    -- instead of failing while indexing it.
    local oem = type(reqbody_oem) == 'table' and reqbody_oem["{{OemIdentifier}}"] or nil
    local max_age_oem, min_age
    if type(oem) == 'table' then
        max_age_oem = oem.PasswordValidityDays
        min_age = oem.MinimumPasswordAgeDays
    end

    -- The top-level value takes precedence over the Oem one
    local max_age = max_age_req
    if max_age == nil then
        max_age = max_age_oem
    end

    -- JSON null clears the setting, which is stored as 0
    local function null_to_zero(value)
        return value ~= cjson.null and value or 0
    end

    -- Only one value supplied: write it directly
    if max_age == nil or min_age == nil then
        if max_age ~= nil then
            account_service_obj.MaxPasswordValidDays = null_to_zero(max_age)
        end
        if min_age ~= nil then
            account_service_obj.MinPasswordValidDays = null_to_zero(min_age)
        end
        return true
    end

    return set_max_and_min_password_age_days_proc(null_to_zero(max_age), null_to_zero(min_age),
        account_service_obj)
end

---Check that an account Id is an integer in [2, 17] whose account resource exists.
-- @param id account Id as a number or numeric string
-- @treturn boolean true when valid; failures are logged and return false
function m.is_valid_account_id(id)
    local number_id = tonumber(id)
    -- tonumber returns nil for strings that are not numerals
    if not number_id then
        log:error('the type of value(%s) for id is invalid', tostring(id))
        return false
    end
    -- Fractional values (e.g. "2.5") cannot name an account. They are logged with
    -- %s because '%u' raises for numbers without an integer representation (Lua 5.3+).
    if number_id ~= math.floor(number_id) or number_id < 2 or number_id > 17 then
        log:error('the id(%s) of account is invalid', tostring(number_id))
        return false
    end
    local path = ACCOUNT_MDB_PATH .. '/' .. id
    local ok, rsp = pcall(mdb_service.is_valid_path, bus, path)
    if not ok or not rsp or not rsp.Result then
        log:error('the id(%s) of account is invalid', tostring(number_id))
        return false
    end

    return true
end

---Validate a request that changes the SNMPv3 authentication protocol.
-- Changing the protocol also requires resetting the SNMPv3 privacy (encryption)
-- password and the login password, so both must be present in the request.
-- @param password new login password from the request, or nil
-- @param snmp_priv_passwd new SNMPv3 privacy password from the request, or nil
-- @return true when both are present
-- @raise ModifyAuthProtocolLackProp (messages.custom) when either is missing
function m.check_set_snmp_auth_protocol_parameter(password, snmp_priv_passwd)
    if password and snmp_priv_passwd then
        return true
    end

    -- Loaded only on the error path; this file has no module-level custom_messages binding
    local custom_messages = require 'messages.custom'
    local err = custom_messages.ModifyAuthProtocolLackProp()
    err.RelatedProperties = {'#/Oem/{{OemIdentifier}}/SnmpV3AuthProtocol',
        '#/Oem/{{OemIdentifier}}/SnmpV3PrivPasswd', '#/Password'}
    error(err)
end

---Return the SNMP read-write / read-only community strings for the caller.
-- Only callers holding 'UserMgmt' see the plaintext values; others get an empty pair.
-- @param ctx request context; ctx.Privilege is the caller's privilege list
-- @param rw_community read-write community string
-- @param ro_community read-only community string
-- @treturn table {rw_community, ro_community}, or {nil, nil} without UserMgmt
function m.get_snmp_community(ctx, rw_community, ro_community)
    if not m.is_have_user_mgmt(ctx.Privilege) then
        return {nil, nil}
    end
    return {rw_community, ro_community}
end

---Convert a list of enabled TLS version names into four per-version flags.
-- @param tls_array list of enabled version names such as 'Tls1.2'
-- @treturn table {tls1_0, tls1_1, tls1_2, tls1_3} booleans
function m.split_tls_enabled(tls_array)
    local enabled = {}
    for _, version in pairs(tls_array) do
        enabled[version] = true
    end

    local function is_on(version)
        return enabled[version] == true
    end

    return {is_on('Tls1.0'), is_on('Tls1.1'), is_on('Tls1.2'), is_on('Tls1.3')}
end

---Turn TLS_1_x_Enabled flags into a list of enabled TLS version names.
-- @param states table with the keys 'TLS_1_0_Enabled'..'TLS_1_3_Enabled'; any
--   truthy value counts as enabled
-- @treturn table enabled version names in ascending order ('Tls1.0' first)
function m.get_tls_version(states)
    -- Listed in ascending version order, which is also the lexicographic order of
    -- the names, so the result needs no further sorting.
    local flag_to_version = {
        {'TLS_1_0_Enabled', 'Tls1.0'},
        {'TLS_1_1_Enabled', 'Tls1.1'},
        {'TLS_1_2_Enabled', 'Tls1.2'},
        {'TLS_1_3_Enabled', 'Tls1.3'},
    }

    local enabled = {}
    for _, entry in ipairs(flag_to_version) do
        local flag, version = entry[1], entry[2]
        if states[flag] then
            enabled[#enabled + 1] = version
        end
    end
    return enabled
end

-- Return the key under which `val` is stored in `tab`, or nil when it is absent.
local function table_find(tab, val)
    local found
    for key, item in pairs(tab) do
        if item == val then
            found = key
            break
        end
    end
    return found
end

---Check whether a LoginInterface update may enable IPMI for the account.
-- Enabling IPMI (present in `interfaces` but not in `cur_interfaces`) requires a
-- non-empty password in the same request, and both password error lists to be empty.
-- @param interfaces requested LoginInterface list
-- @param cur_interfaces currently configured LoginInterface list
-- @param pwd password supplied in the request, or nil
-- @param set_pwd_err1 first list of password-setting errors (must have length 0)
-- @param set_pwd_err2 second list of password-setting errors (must have length 0)
-- @return interfaces unchanged when the check passes
-- @raise AccountMustResetPassword (messages.custom) otherwise
function m.check_login_interface_condition(interfaces, cur_interfaces, pwd, set_pwd_err1, set_pwd_err2)
    local enabling_ipmi = table_find(interfaces, "IPMI") ~= nil and table_find(cur_interfaces, "IPMI") == nil
    if not enabling_ipmi then
        return interfaces
    end

    local pwd_missing = not pwd or #pwd == 0
    if pwd_missing or #set_pwd_err1 > 0 or #set_pwd_err2 > 0 then
        -- Built only when needed; this file has no module-level custom_messages binding
        local custom_messages = require 'messages.custom'
        local err = custom_messages.AccountMustResetPassword()
        err.RelatedProperties = {'#/Oem/{{OemIdentifier}}/LoginInterface'}
        error(err)
    end
    return interfaces
end

---Format an account's LastLoginTime for Redfish.
-- The timestamp is rendered with os.date's '!' (UTC) prefix as YYYY-MM-DDTHH:MM:SS.
-- It is followed by the first non-empty DateTimeLocalOffset found among the Managers
-- time objects, or '+00:00' when none is configured.
-- NOTE(review): the digits are UTC while the suffix is the configured local offset;
-- this is only consistent if time_stamp is already shifted to local time. Confirm.
-- @param time_stamp login time in seconds; 0xffffffff is the default value for new accounts
-- @treturn string formatted date-time, or "" for the new-account default
function m.get_formatted_last_login_time(time_stamp)
    -- New accounts carry the default LastLoginTime 0xffffffff
    if time_stamp == 0xffffffff then
        return ""
    end
    local formatted_time = os.date('!%Y-%m-%dT%H:%M:%S', time_stamp)

    -- Without a fallback, a missing offset would leave nil and make `..` raise
    local date_time_offset = '+00:00'
    local time_objs = mdb.get_sub_objects(bus, TIME_MDB_PATH, TIME_MDB_INTF) or {}
    for _, time_obj in pairs(time_objs) do
        local offset = time_obj.DateTimeLocalOffset
        if offset ~= nil and offset ~= '' then
            date_time_offset = offset
            break
        end
    end
    return formatted_time .. date_time_offset
end

-- Export the module table holding the public helpers defined above
return m
