-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local utils = require 'mc.utils'
local log = require 'mc.logging'
local cjson = require 'cjson'
local mdb_service = require 'mc.mdb.mdb_service'
local custom_msg = require 'messages.custom'

-- NOTE(review): MDB_GET_PATH is not referenced in the visible part of this file — confirm whether it is still needed.
local MDB_GET_PATH<const> = 'GetPath'
-- Interface names passed to mdb_service when looking up PCIe card / DPU card objects.
local PCIECARD_INTERFACE = 'bmc.kepler.Systems.PCIeDevices.PCIeCard'
local DPUCARD_INTERFACE = 'bmc.kepler.Systems.DPUCard'

-- Invalid path returned by the GetPath function
local INVALID_PATH<const> = ''

-- Module table: public functions are attached to it and it is returned at the end of the file.
local m = {}

-- Verify that a PCIe card object exists for the given slot before an NMI is set on it.
--
-- @param slot_id  slot identifier; must be convertible with tonumber()
-- @return true when a PCIeCard path is found for the slot
-- Raises custom_msg.SdiCardNotSupport() when slot_id is nil or not numeric,
-- when the path lookup fails, or when the lookup returns the invalid (empty) path.
--
-- NOTE(review): `bus` is not declared in this file; presumably it is a global
-- provided by the host environment — confirm.
function m.is_valid_set_nmi_path(slot_id)
    if slot_id == nil then
        log:error('is_valid_set_nmi_path slot_id is nil')
        error(custom_msg.SdiCardNotSupport())
    end

    -- tonumber() yields nil for non-numeric input. {SlotID = nil} would be an
    -- empty table, i.e. a filter that no longer constrains the slot at all, so
    -- such input is rejected here instead of being sent to the lookup.
    local slot_num = tonumber(slot_id)
    if slot_num == nil then
        log:error('is_valid_set_nmi_path slot_id(%s) is not a number', tostring(slot_id))
        error(custom_msg.SdiCardNotSupport())
    end

    local filter = {SlotID = slot_num}
    local ok, rsp = pcall(mdb_service.get_path, bus, PCIECARD_INTERFACE, cjson.encode(filter), false)
    if not ok then
        log:error('Verify SlotID(%s) failed, err(%s)', slot_id, rsp.message)
        error(custom_msg.SdiCardNotSupport())
    end

    -- An empty Path means no PCIeCard object matched the filter.
    if rsp.Path == INVALID_PATH then
        log:error('Get invalid path')
        error(custom_msg.SdiCardNotSupport())
    end

    return true
end

-- Count every entry (array and hash part) of a table.
-- @param tbl  value to inspect
-- @return number of key/value pairs visited by pairs(); returns nothing when
--         tbl is not a table
local function get_table_len(tbl)
    if type(tbl) == 'table' then
        local total = 0
        for _ in pairs(tbl) do
            total = total + 1
        end
        return total
    end
end

-- Determine whether the object at `path` is a DPU card.
-- The object is queried for DPUCARD_INTERFACE; the card counts as a DPU card
-- only when the query succeeds and the response's Object field is a non-empty table.
-- @param path  object path to query
-- @return true for a DPU card, false otherwise (including when the query raises)
local function is_dpu_card(path)
    local success, result = pcall(mdb_service.get_object, bus, path, {DPUCARD_INTERFACE})
    if success then
        local entry_count = get_table_len(result.Object)
        -- entry_count is nil when Object is not a table.
        return entry_count ~= nil and entry_count ~= 0
    end
    return false
end

-- Verify that the PCIe card in the given slot exists and is a DPU card.
--
-- @param slot_id  slot identifier, used as-is as the SlotID filter value
-- @return true when the slot holds a DPU card
-- Raises custom_msg.InvalidValue(..., 'SlotId') when slot_id is nil/false,
-- when the path lookup fails or yields the invalid (empty) path, or when the
-- card found is not a DPU card.
function m.is_valid_dpu_card(slot_id)
    if not slot_id then
        log:error('slot_id is nil')
        error(custom_msg.InvalidValue('nil', 'SlotId'))
    end

    -- Look up the PCIeCard object whose SlotID matches.
    local query = cjson.encode({SlotID = slot_id})
    local found, result = pcall(mdb_service.get_path, bus, PCIECARD_INTERFACE, query, false)
    if not found then
        log:error('Verify SlotID(%s) failed, err(%s)', slot_id, result.message)
        error(custom_msg.InvalidValue(tostring(slot_id), 'SlotId'))
    end

    local card_path = result.Path
    if card_path == INVALID_PATH then
        log:error('Get invalid path')
        error(custom_msg.InvalidValue(tostring(slot_id), 'SlotId'))
    end

    if is_dpu_card(card_path) then
        return true
    end

    log:error('PCIeCard%s is not dpucard', slot_id)
    error(custom_msg.InvalidValue(tostring(slot_id), 'SlotId'))
end

-- Build the list of BDF strings ("ssss:bb:dd.f", hexadecimal) for every
-- PCIeFunction sub-object found under device_path .. '/PCIeFunctions/'.
--
-- @param device_path  object path of the PCIe device
-- @param bdf          legacy single BDF value; kept for compatibility with the
--                     old implementation so that at least the old content is
--                     returned when no sub-object is found
-- @return array of BDF strings ordered by the key
--         "<RelatedProcessorId>_<Bus><Device><Function>"; an entry whose
--         formatted BDF is "0000:ff:ff.ff" is replaced by cjson.null
--
-- NOTE(review): `mdb` and `bus` are not declared in this file; presumably they
-- are globals provided by the host environment — confirm.
function m.get_multi_bdfs(device_path, bdf)
    local sub_objects = mdb.get_sub_objects(bus, device_path .. '/PCIeFunctions/', 'bmc.kepler.Systems.PCIeDevice.PCIeFunction')

    -- One record per sub-object. The sort key concatenates the numbers without
    -- separators, so two different functions can share a key (bus 1/dev 23 and
    -- bus 12/dev 3 both give "...123..."). Keeping records in a list rather
    -- than a key -> BDF map ensures such functions are all reported instead of
    -- one overwriting the other.
    local entries = {}
    for _, v in pairs(sub_objects) do
        entries[#entries + 1] = {
            name = v.RelatedProcessorId .. '_' .. v.BusNumber .. v.DeviceNumber .. v.FunctionNumber,
            bdf = string.format("%04x:%02x:%02x.%x",
                    v.SegmentNumber, v.BusNumber, v.DeviceNumber, v.FunctionNumber)
        }
    end

    -- Order by key exactly as before; fall back to the BDF text so that
    -- entries with equal keys come out in a deterministic order.
    table.sort(entries, function(a, b)
        if a.name ~= b.name then
            return a.name < b.name
        end
        return a.bdf < b.bdf
    end)

    local bdfs = {}
    for _, entry in ipairs(entries) do
        if entry.bdf == "0000:ff:ff.ff" then
            bdfs[#bdfs + 1] = cjson.null
        else
            bdfs[#bdfs + 1] = entry.bdf
        end
    end

    -- No sub-objects: return the legacy value so old content is preserved.
    if #bdfs == 0 then
        bdfs[1] = bdf
    end

    return bdfs
end

-- Build the list of root-port BDF strings ("ssss:bb:dd.f", hexadecimal) for
-- every PCIeFunction sub-object found under device_path .. '/PCIeFunctions/'.
--
-- @param device_path  object path of the PCIe device
-- @param rootbdf      legacy single root BDF value; returned as the only
--                     element when no sub-object is found
-- @return array of root BDF strings ordered by the key
--         "<RelatedProcessorId>_<RootBus><RootDevice><RootFunction>"; an entry
--         whose formatted BDF is "0000:ff:ff.ff" is replaced by cjson.null
--
-- NOTE(review): `mdb` and `bus` are not declared in this file; presumably they
-- are globals provided by the host environment — confirm.
function m.get_multi_rootbdfs(device_path, rootbdf)
    local sub_objects = mdb.get_sub_objects(bus, device_path .. '/PCIeFunctions/', 'bmc.kepler.Systems.PCIeDevice.PCIeFunction')

    -- One record per sub-object. The sort key concatenates the numbers without
    -- separators, so two different functions can share a key (bus 1/dev 23 and
    -- bus 12/dev 3 both give "...123..."). Keeping records in a list rather
    -- than a key -> BDF map ensures such functions are all reported instead of
    -- one overwriting the other.
    local entries = {}
    for _, v in pairs(sub_objects) do
        entries[#entries + 1] = {
            name = v.RelatedProcessorId .. '_' .. v.RootBusNumber .. v.RootDeviceNumber .. v.RootFunctionNumber,
            bdf = string.format("%04x:%02x:%02x.%x",
                    v.SegmentNumber, v.RootBusNumber, v.RootDeviceNumber, v.RootFunctionNumber)
        }
    end

    -- Order by key exactly as before; fall back to the BDF text so that
    -- entries with equal keys come out in a deterministic order.
    table.sort(entries, function(a, b)
        if a.name ~= b.name then
            return a.name < b.name
        end
        return a.bdf < b.bdf
    end)

    local rootbdfs = {}
    for _, entry in ipairs(entries) do
        if entry.bdf == "0000:ff:ff.ff" then
            rootbdfs[#rootbdfs + 1] = cjson.null
        else
            rootbdfs[#rootbdfs + 1] = entry.bdf
        end
    end

    -- No sub-objects: return the legacy value so old content is preserved.
    if #rootbdfs == 0 then
        rootbdfs[1] = rootbdf
    end

    return rootbdfs
end

-- Export the module table.
return m
