-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.
-- NOTE(review): `mdb`, `bus` and `mdb_service` are used below as globals; they are
-- presumably injected by the hosting runtime — confirm before running standalone.
local cjson = require 'cjson'
-- Module table that collects the exported helpers (returned at the end of the file).
local m = {}
-- MDB interface implemented by memory objects; read by m.get_hbm_path.
local MEM_MDB_INTF <const> = "bmc.kepler.Systems.Memory"
local log = require 'mc.logging'
-- Fallback JSON payloads decoded by m.get_npu_ecc when no ECC info is supplied:
-- BIT_ECC_DEFAULT for 'SingleBitEcc' / 'MultiBitEcc', BIT_ISOLATED_DEFAULT otherwise.
local BIT_ECC_DEFAULT <const> = '{"AggregateTotalCount": 0, "Count": 0, "Info": null}'
local BIT_ISOLATED_DEFAULT <const> ='{"Count": 0, "Info": null}'
-- MDB interfaces used when locating the PCIe card that hosts an NPU.
local PCIECARD_INTERFACE<const> = 'bmc.kepler.Systems.PCIeDevices.PCIeCard'
local PROPERTY_INTERFACE<const> = 'bmc.kepler.Object.Properties'
-- NOTE(review): not referenced anywhere in the visible part of this file;
-- presumably an "empty path" sentinel — confirm before removing.
local INVALID_PATH<const> = ''

--- Build the dash-separated identification string of a processor.
-- An explicit string always wins. Otherwise the numeric id is rendered as
-- 16 upper-case hex digits and emitted byte pair by byte pair from the least
-- significant end, e.g. 0x0102030405060708 -> "08-07-06-05-04-03-02-01".
-- @param processor_id integer id, or nil
-- @param processor_id_str preformatted identification, or nil
-- @return the identification string, or nil when neither argument is given
function m.format_identification(processor_id, processor_id_str)
    if processor_id_str then
        return processor_id_str
    end
    if not processor_id then
        return nil
    end

    local hex = string.format('%016X', processor_id)
    local byte_pairs = {}
    -- Walk backwards two characters at a time so the lowest byte comes first.
    for pos = #hex - 1, 1, -2 do
        byte_pairs[#byte_pairs + 1] = string.sub(hex, pos, pos + 1)
    end
    return table.concat(byte_pairs, '-')
end

--- Find the PCIe card that hosts the NPU at npu_path.
-- A truthy result means the NPU sits on a PCIe (NPU) card. The NPU and the
-- card are paired when element [4] of their ObjectIdentifier properties is equal.
-- Cards whose PCIeCard interface cannot be fetched are skipped.
-- @param npu_path MDB object path of the NPU
-- @param systemid system id used to build the PCIeCards collection path
-- @return the PCIeCard object, or nothing when no card matches
local function get_pcie_card(npu_path, systemid)
    local npu_props = mdb.get_object(bus, npu_path, PROPERTY_INTERFACE)
    local npu_position = npu_props.ObjectIdentifier[4]
    local cards_root = "/bmc/kepler/Systems/" .. systemid .. "/PCIeDevices/PCIeCards"

    for card_path, card_props in pairs(mdb.get_sub_objects(bus, cards_root, PROPERTY_INTERFACE)) do
        if card_props.ObjectIdentifier[4] == npu_position then
            local fetched, pcie_card = pcall(mdb.get_object, bus, card_path, PCIECARD_INTERFACE)
            if fetched then
                return pcie_card
            end
        end
    end
end

--- Resolve the physical id of every device path in order.
-- Paths for which m.get_id yields nothing are silently skipped.
-- @param cpu_paths array of CPU/GPU/NPU object paths
-- @param systemid system id forwarded to m.get_id
-- @return array of resolved ids
function m.get_device_physical_id(cpu_paths, systemid)
    local ids = {}
    for _, device_path in ipairs(cpu_paths) do
        local physical_id = m.get_id(device_path, systemid)
        if physical_id then
            ids[#ids + 1] = physical_id
        end
    end
    return ids
end

--- Map a processor object path to its display id.
-- CPU -> PhysicalId; GPU -> 'Gpu<Slot>'; NPU -> 'Npu<slot>-<index>' when the NPU
-- sits on a PCIe card (two NPUs per card, Id 1,2 -> slot 1; Id 3,4 -> slot 2),
-- otherwise 'Npu<Id>'.
-- @param path processor object path (type detected by substring)
-- @param systemid system id used for the PCIe card lookup
-- @return the id, or nothing for an unrecognised path
function m.get_id(path, systemid)
    if string.match(path, "CPU") then
        return mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor.CPU').PhysicalId
    end

    if string.match(path, "GPU") then
        local gpu = mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor.GPU')
        return 'Gpu' .. gpu.Slot
    end

    if not string.match(path, "NPU") then
        return
    end

    local npu = mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor')
    if not get_pcie_card(path, systemid) then
        return 'Npu' .. npu.Id
    end
    local ordinal = npu.Id + 1
    local slot = ordinal // 2
    local index = ordinal % 2 + 1
    return 'Npu' .. slot .. '-' .. index
end

--- Map a single processor object path to its display id.
-- Unlike m.get_id this never looks up a PCIe card, so NPUs are always 'Npu<Id>'.
-- @param path processor object path (type detected by substring)
-- @return the id, or '' for an unrecognised path
function m.get_single_device_physical_id(path)
    if string.match(path, "CPU") then
        local cpu = mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor.CPU')
        return cpu.PhysicalId
    end

    if string.match(path, "GPU") then
        local gpu = mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor.GPU')
        return 'Gpu' .. gpu.Slot
    end

    if string.match(path, "NPU") then
        local npu = mdb.get_object(bus, path, 'bmc.kepler.Systems.Processor')
        return 'Npu' .. npu.Id
    end

    return ''
end

--- Replace the 32768 sentinel with "NA" in a two-level NvLink table.
-- The table is modified in place and also returned.
-- @param nvlink_info array of arrays of values
-- @return the same table
function m.format_nvlink_info(nvlink_info)
    local NA_SENTINEL <const> = 32768
    for row_idx = 1, #nvlink_info do
        local row = nvlink_info[row_idx]
        for col_idx = 1, #row do
            if row[col_idx] == NA_SENTINEL then
                row[col_idx] = "NA"
            end
        end
    end
    return nvlink_info
end

--- Convert raw GPU NvLink rows into a cjson array of named objects.
-- Column j of each row is emitted under the field name at position j
-- (NvLinkStatus, ReplayErrorCount, ...). The sentinel 32768 is emitted as "NA".
-- The input table is NOT modified. Columns beyond the known field names are skipped.
-- @param nvlink_info array of arrays of values
-- @return cjson array object, one element per row
function m.get_gpu_nvlink_info(nvlink_info)
    local NA_SENTINEL <const> = 32768
    local param_idx = {
        [1] = "NvLinkStatus",
        [2] = "ReplayErrorCount",
        [3] = "RecoveryErrorCount",
        [4] = "FlitCRCErrorCount",
        [5] = "DataCRCErrorCount"
    }
    local info = cjson.json_object_new_array()

    for i = 1, #nvlink_info do
        local row = nvlink_info[i]
        local element = cjson.json_object_new_object()
        for j = 1, #row do
            local field = param_idx[j]
            -- Skip unexpected extra columns instead of writing under a nil key.
            if field then
                local value = row[j]
                if value == NA_SENTINEL then
                    value = "NA"
                end
                element[field] = value
            end
        end
        info[i] = element
    end
    return info
end

--- Find the memory object attached to a CPU whose device type mentions DRAM.
-- Objects whose memory interface cannot be fetched are skipped.
-- @param mem_paths array of memory object paths
-- @param cpuid CpuId to match
-- @return the first matching path, or '' when none matches
function m.get_hbm_path(mem_paths, cpuid)
    for _, mem_path in ipairs(mem_paths) do
        local fetched, mem = pcall(mdb.get_object, bus, mem_path, MEM_MDB_INTF)
        local is_match = fetched and mem.CpuId == cpuid and string.find(mem.MemoryDeviceType, 'DRAM')
        if is_match then
            return mem_path
        end
    end
    return ''
end

--- Decode NPU ECC info JSON, falling back to an all-zero payload.
-- When ecc_info is missing the failure is logged and a default is decoded:
-- the aggregate/count payload for 'SingleBitEcc' and 'MultiBitEcc', the plain
-- count payload for every other info_type.
-- @param ecc_info JSON string, or nil/false
-- @param info_type name of the ECC category (used for logging and fallback)
-- @return ordered cjson object
function m.get_npu_ecc(ecc_info, info_type)
    if not ecc_info then
        log:error('Get %s failed.', info_type)
        local is_bit_ecc = info_type == 'SingleBitEcc' or info_type == 'MultiBitEcc'
        local fallback = is_bit_ecc and BIT_ECC_DEFAULT or BIT_ISOLATED_DEFAULT
        return cjson.json_object_ordered_decode(fallback)
    end
    return cjson.json_object_ordered_decode(ecc_info)
end

--- Describe where a card-mounted NPU lives, e.g. 'PCIe Card 3 (NPU)'.
-- @param npu_path MDB object path of the NPU
-- @param systemid system id used for the PCIe card lookup
-- @return locator string, or cjson.null when the NPU is not on a PCIe card
function m.get_pcie_npu_device_locator(npu_path, systemid)
    local pcie_card = get_pcie_card(npu_path, systemid)
    if pcie_card then
        return 'PCIe Card ' .. pcie_card.SlotID .. ' (NPU)'
    end
    return cjson.null
end

--- Describe the riser holding a card-mounted NPU, e.g. 'PCIe Riser1'.
-- The riser number is the first run of digits in the card's Position property.
-- @param npu_path MDB object path of the NPU
-- @param systemid system id used for the PCIe card lookup
-- @return position string, or cjson.null when there is no card or its Position
--         carries no number (previously this raised on the nil concatenation)
function m.get_pcie_npu_position(npu_path, systemid)
    local pcie_card = get_pcie_card(npu_path, systemid)
    if not pcie_card then
        return cjson.null
    end
    local position = pcie_card.Position
    if position == nil then
        return cjson.null
    end
    local container_slot = string.match(tostring(position), '%d+')
    if not container_slot then
        return cjson.null
    end
    -- NOTE(review): no space after 'Riser' (unlike 'PCIe Card ' in the locator);
    -- kept as-is for output compatibility — confirm it is intended.
    return 'PCIe Riser' .. container_slot
end

--- Ask the MDB service whether the NPU's '/Ports/' sub-path exists.
-- Errors from the service are logged and re-raised unchanged.
-- @param npu_path MDB object path of the NPU
-- @return the service's Result field
function m.is_valid_npu_port_path(npu_path)
    local path = npu_path .. '/Ports/'
    local ok, rsp = pcall(mdb_service.is_valid_path, bus, path)
    if not ok then
        -- Error objects may be tables carrying `message` or plain strings; log
        -- whichever text is available instead of a bare "nil".
        local err_msg = (type(rsp) == 'table' and rsp.message) or rsp
        log:error('is_valid_npu_port_path failed, err(%s)', tostring(err_msg))
        error(rsp)
    end

    return rsp.Result
end

--- Find the MDB path of an NPU port.
-- With port_id, returns the port whose Id equals tonumber(port_id). Without it,
-- returns any port. Among several candidates the lexicographically smallest path
-- is chosen, so the result no longer depends on pairs() traversal order.
-- @param npu_path MDB object path of the NPU (must contain 'NPU')
-- @param port_id optional port id (number or numeric string)
-- @return the port path, or '' when none matches
function m.get_npu_port_path(npu_path, port_id)
    if not npu_path or not string.find(npu_path, 'NPU') then
        return ''
    end
    local port_path = npu_path .. '/Ports/'
    local port_list = mdb.get_sub_objects(bus, port_path,
        'bmc.kepler.Systems.Processor.Port')
    -- Loop-invariant: convert the requested id once.
    local wanted_id = port_id and tonumber(port_id)
    local chosen
    for path, obj in pairs(port_list) do
        if (not port_id or obj.Id == wanted_id) and (chosen == nil or path < chosen) then
            chosen = path
        end
    end
    return chosen or ''
end

--- Read one property from each NPU port object and return the values sorted ascending.
-- Ports whose property is nil contribute nothing.
-- @param npu_port_paths array of port object paths
-- @param property name of the Port property to read
-- @return sorted array of property values
function m.get_npu_port_propertys(npu_port_paths, property)
    local values = {}
    for _, port_path in ipairs(npu_port_paths) do
        local port = mdb.get_object(bus, port_path, 'bmc.kepler.Systems.Processor.Port')
        values[#values + 1] = port[property]
    end
    -- The default ordering of table.sort is the `<` operator.
    table.sort(values)
    return values
end

--- Convert a Redfish processor id ('NpuN' or 'Npu<slot>-<index>') to the NPU Id.
-- 'Npu<slot>-<index>' maps to 2 * (slot - 1) + index, which is the inverse of
-- the naming in m.get_id. slot must be >= 1 and index must be 1 or 2. Multi-digit
-- slots and indices are accepted; the old single-digit pattern misread "npu10-1"
-- as 10 and "npu1-12" as 1.
-- @param url_processor_id processor id taken from the URL (case-insensitive)
-- @return the numeric NPU Id, or nil when the id is malformed
function m.get_npu_processor_id(url_processor_id)
    if type(url_processor_id) ~= 'string' then
        return nil
    end
    local id = string.lower(url_processor_id)
    -- The processorId in the URL has the form npu<N> or npu<slot>-<index>.
    local slot, index = string.match(id, 'npu(%d+)%-(%d+)')
    if slot then
        slot, index = tonumber(slot), tonumber(index)
        if slot >= 1 and index >= 1 and index <= 2 then
            return 2 * (slot - 1) + index
        end
        return nil
    end
    local single = string.match(id, 'npu(%d+)')
    if single then
        return tonumber(single)
    end
    return nil
end

--- Collect NPU Ids referenced by a list of ProcessorPorts entries.
-- Entry [1] is an NPU name: 'npu<slot>-<index>' maps to 2 * slot + index - 2,
-- and 'npu<N>' maps to N. Entries whose name is not an NPU name are skipped
-- (previously they were appended as 0). The sentinel Id 255 is still excluded.
-- @param processor_ports iterable of entries whose first element is the NPU name
-- @return array of NPU Ids
function m.get_related_processor(processor_ports)
    local INVALID_ID <const> = 255
    local processorids = {}
    for _, processor_port in pairs(processor_ports) do
        local name = string.lower(processor_port[1])
        local processorid
        local slot, index = string.match(name, 'npu(%d+)%-(%d+)')
        if slot then
            processorid = 2 * tonumber(slot) + tonumber(index) - 2
        else
            local single = string.match(name, 'npu(%d+)')
            if single then
                processorid = tonumber(single)
            end
        end
        if processorid and processorid ~= INVALID_ID then
            processorids[#processorids + 1] = processorid
        end
    end
    return processorids
end

local PORT_URL_PATTERN<const> = '/redfish/v1/Chassis/%s/NetworkAdapters/%s/Ports/%s'

-- Strict ordering for port links: by trailing port number first, then by full URL.
-- The URL tie-break makes equal port numbers on different adapters sort
-- deterministically instead of following pairs() order.
local function port_link_less(a, b)
    local url_a, url_b = a['@odata.id'], b['@odata.id']
    local num_a = tonumber(string.match(url_a, '/([^/]+)$'))
    local num_b = tonumber(string.match(url_b, '/([^/]+)$'))
    if num_a and num_b and num_a ~= num_b then
        return num_a < num_b
    end
    return url_a < url_b
end

--- Build Redfish links to the network ports wired to one NPU port.
-- Only adapters whose Model contains 'NPU' are scanned. A network port is linked
-- when one of its ProcessorPorts entries names this NPU and port id. Each network
-- port is linked at most once.
-- @param npu_port_id processor port id to match (entry [2])
-- @param npu_name NPU name to match (entry [1])
-- @param systemid system id (also used as the Chassis id in the URL)
-- @return array of { ['@odata.id'] = url }, sorted by port number
function m.get_network_ports(npu_port_id, npu_name, systemid)
    local network_ports_list = {}
    -- First enumerate the network adapters under this system.
    local network_adapter_path = string.format('/bmc/kepler/Systems/%s/NetworkAdapters', systemid)
    local network_adapter_list = mdb.get_sub_objects(bus, network_adapter_path, 'bmc.kepler.Systems.NetworkAdapter')
    for path, network_adapter in pairs(network_adapter_list) do
        -- Filter target adapters by Model.
        if network_adapter.Model and string.match(network_adapter.Model, 'NPU') then
            local port_list = mdb.get_sub_objects(bus, path .. '/Ports', 'bmc.kepler.Systems.NetworkPort.RelatedItems')
            -- Each ProcessorPorts entry of the RelatedItems interface holds:
            -- [1] NPU name, [2] processor port, [3] processor IO die.
            for port_path, port in pairs(port_list) do
                for _, processor_port in pairs(port.ProcessorPorts) do
                    if processor_port[1] == npu_name and processor_port[2] == npu_port_id then
                        local port_obj = mdb.get_object(bus, port_path, 'bmc.kepler.Systems.NetworkPort')
                        network_ports_list[#network_ports_list + 1] = {
                            ['@odata.id'] = string.format(PORT_URL_PATTERN, systemid, network_adapter.NodeId, port_obj.PortID + 1)
                        }
                        -- Several entries (e.g. different IO dies) may match; link the port once.
                        break
                    end
                end
            end
        end
    end
    table.sort(network_ports_list, port_link_less)
    return network_ports_list
end

-- Export the helper table as the module value.
return m