-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local c_object_manage = require 'mc.orm.object_manage'
local pmu_cmd = require 'imu'

-- Module table: the public update_* functions defined in this file are attached
-- to it, and it is the value returned by the file.
local card_resource_update_bdf = {}

-- Number of bytes requested from the PMU for one pair of 16-bit IDs.
local PCIE_ID_PAIR_READ_LENGTH = 4

-- Reads 4 bytes at `address` for the PCIe function described by `pcie_info`
-- and decodes them as two little-endian 16-bit IDs.
--
-- pcie_info: table with system_id, is_local, cpu_id, bus_num, device_num and
--            function_num; the fields are copied unchanged into the request.
-- address:   register address handed to the PMU (one of the entries of
--            pmu_cmd.pci_info_address_list).
--
-- Returns first_id (bytes 1-2) and second_id (bytes 3-4). Returns nothing when
-- the PMU gives back nil or fewer than 4 bytes, or when both IDs read 0xffff
-- (presumably an all-ones read means nothing answered -- confirm against the
-- PMU behaviour). Errors raised by the PMU call are not caught here.
local function read_pcie_id_pair(pcie_info, address)
    local payload = {
        system_id = pcie_info.system_id,
        is_local = pcie_info.is_local,
        cpu_id = pcie_info.cpu_id,
        bus_num = pcie_info.bus_num,
        device_num = pcie_info.device_num,
        function_num = pcie_info.function_num,
        address = address,
        read_length = PCIE_ID_PAIR_READ_LENGTH
    }
    local bus = c_object_manage.get_instance().bus
    local info = pmu_cmd.get_info_from_pmu(bus, payload)
    if info == nil or #info < PCIE_ID_PAIR_READ_LENGTH then
        return
    end
    -- Each ID is stored low byte first; mask every byte to 8 bits before combining.
    local second_id = ((info[4] & 0xff) << 8) + (info[3] & 0xff)
    local first_id = ((info[2] & 0xff) << 8) + (info[1] & 0xff)
    if first_id == 0xffff and second_id == 0xffff then
        return
    end
    return first_id, second_id
end

-- Reads the vendor ID and device ID of the PCIe function described by
-- `pcie_info` through the PMU.
-- Returns vendor_id, device_id, or nothing when the read is unusable
-- (see read_pcie_id_pair).
local function get_pcie_vid_did_info(pcie_info)
    return read_pcie_id_pair(pcie_info, pmu_cmd.pci_info_address_list.VID_DID_INFO_ADDR)
end

-- Reads the subsystem vendor ID and subsystem device ID of the PCIe function
-- described by `pcie_info` through the PMU.
-- Returns sub_vendor_id, sub_device_id, or nothing when the read is unusable
-- (see read_pcie_id_pair).
local function get_pcie_sub_vid_did_info(pcie_info)
    return read_pcie_id_pair(pcie_info, pmu_cmd.pci_info_address_list.SUBVID_SUBDID_INFO_ADDR)
end

-- Reads the four PCIe identifiers (vendor, device, subsystem vendor, subsystem
-- device) of the function described by `pcie_info`.
--
-- Both reads run under pcall so an error raised while talking to the PMU does
-- not propagate to the caller; the error value is logged instead of being
-- silently dropped.
--
-- Returns true, vendor_id, device_id, sub_vendor_id, sub_device_id on success.
-- Returns a single false when either read raises or yields no IDs.
local function parse_vid_did_svid_sdid(pcie_info)
    local ok, vendor_id, device_id = pcall(get_pcie_vid_did_info, pcie_info)
    if not ok then
        -- On failure pcall's second result is the error value, not an ID.
        log:error('get pcie vid did info failed, err: %s', tostring(vendor_id))
        return false
    end
    if not vendor_id then
        return false
    end

    local ret, sub_vendor_id, sub_device_id = pcall(get_pcie_sub_vid_did_info, pcie_info)
    if not ret then
        -- Same convention as above: the second result carries the error value.
        log:error('get pcie sub vid did info failed, err: %s', tostring(sub_vendor_id))
        return false
    end
    if not sub_vendor_id then
        return false
    end

    log:info('vendor_id: %s, device_id: %s, sub_vendor_id:%s, sub_device_id: %s', vendor_id,
        device_id, sub_vendor_id, sub_device_id)
    return true, vendor_id, device_id, sub_vendor_id, sub_device_id
end

-- Checks that the four PCIe identifiers can be read for `pcie_info`.
-- All results of parse_vid_did_svid_sdid are forwarded unchanged by the tail
-- call; the callers in this file assign the result to a single variable and
-- therefore only look at the first one (the boolean success flag).
local function verify_vid_did_svid_sdid(pcie_info)
    -- Verifying the four-ID tuple only uses the first return value
    return parse_vid_did_svid_sdid(pcie_info)
end

-- Assigns a BDF string to each port of a 1822 card.
--
-- resource_obj: card object; NodeId, tasks, DevBus, DevDevice, DevFunction,
--               SocketId and NetworkPortCount are read from it.
-- ports:        table of port objects; PortID is read and BDF is written.
--
-- The function first polls resource_obj.DevBus once per second (always
-- sleeping at least once, at most 121 times) until it is set and non-zero, and
-- gives up with an error log if it never becomes valid. For every port the
-- PCIe function number is DevFunction + PortID; the BDF is written only when
-- the four PCIe IDs can be read for that function and the string formats
-- successfully. An error is logged when the number of ports written differs
-- from NetworkPortCount. Nothing is returned.
function card_resource_update_bdf.update_1822_port_bdf(resource_obj, ports)
    log:notice("update 1822 port bdf start, name:%s", resource_obj.NodeId)

    local poll_interval_ms = 1000
    local max_polls = 120
    local polls = 0
    while true do
        resource_obj.tasks:sleep_ms(poll_interval_ms)
        polls = polls + 1
        -- The exit test runs after the sleep, exactly like a repeat/until loop.
        if (resource_obj.DevBus and resource_obj.DevBus ~= 0) or polls > max_polls then
            break
        end
    end

    -- Polling may have ended because of the retry limit, so re-check DevBus.
    if not (resource_obj.DevBus and resource_obj.DevBus ~= 0) then
        log:error('update 1822 port bdf, the value of devbus is invalid')
        return
    end

    local expected_ports = resource_obj.NetworkPortCount
    local updated_ports = 0
    for _, port in pairs(ports) do
        log:notice("update 1822 port bdf, name:%s, DevBus=%s, DevDevice=%s, port_id=%s, DevFunction=%s",
            resource_obj.NodeId, resource_obj.DevBus, resource_obj.DevDevice, port.PortID, resource_obj.DevFunction)

        local query = {
            system_id = 1,
            is_local = false,
            cpu_id = resource_obj.SocketId,
            bus_num = resource_obj.DevBus,
            device_num = resource_obj.DevDevice,
            function_num = resource_obj.DevFunction + port.PortID
        }
        -- Only the first result (the success flag) of the verification is kept.
        local verified = verify_vid_did_svid_sdid(query)
        -- string.format runs under pcall so a bad field cannot abort the loop.
        local formatted, bdf = pcall(string.format, '0000:%02x:%02x.%01x', query.bus_num,
            query.device_num, query.function_num)

        if verified and formatted then
            log:notice("update 1822 port bdf, name:%s, port = %s, set port BDF val=%s",
                resource_obj.NodeId, port.PortID, bdf)
            port.BDF = bdf
            updated_ports = updated_ports + 1
        end
    end

    if updated_ports ~= expected_ports then
        log:error('update 1822 port bdf failed, name:%s', resource_obj.NodeId)
    end
end

-- Reports whether `resource_obj` carries no usable DevBus value (nil/false or 0).
local function is_dev_bus_unset(resource_obj)
    return not resource_obj.DevBus or resource_obj.DevBus == 0
end

-- Builds the PMU query descriptor for one port of `resource_obj`.
-- A Type 1 object without a usable DevBus is addressed through Bus + 1 / Device
-- (presumably the bus directly behind the object's own Bus -- confirm with the
-- hardware topology); every other object is addressed through
-- DevBus / DevDevice. The function number is always the port's PortID.
local function build_port_pcie_info(resource_obj, port)
    local bus_num, device_num
    if resource_obj.Type == 1 and is_dev_bus_unset(resource_obj) then
        bus_num = resource_obj.Bus + 1
        device_num = resource_obj.Device
    else
        bus_num = resource_obj.DevBus
        device_num = resource_obj.DevDevice
    end
    return {
        system_id = 1,
        is_local = false,
        cpu_id = resource_obj.SocketId,
        bus_num = bus_num,
        device_num = device_num,
        function_num = port.PortID
    }
end

-- Fills in the BDF string of every port whose BDF is still the empty string.
--
-- resource_obj: card object; NodeId, Type, DevBus, DevDevice, Bus, Device,
--               SocketId and NetworkPortCount are read from it.
-- ports:        table of port objects; PortID and BDF are read, BDF is written.
--
-- Returns early (with a notice log) when the bus needed for the object's Type
-- is missing: DevBus for objects whose Type is not 1, Bus for Type 1 objects.
-- A port is updated only when the four PCIe IDs can be read for its function,
-- the BDF string formats successfully and its BDF is currently ''. An error is
-- logged when the number of ports written differs from NetworkPortCount.
-- Nothing is returned.
function card_resource_update_bdf.update_port_bdf(resource_obj, ports)
    local port_num = resource_obj.NetworkPortCount
    local count = 0
    if resource_obj.Type ~= 1 and is_dev_bus_unset(resource_obj) then
        log:notice("update port bdf return, DevBus check fail, name:%s", resource_obj.NodeId)
        return
    end

    if resource_obj.Type == 1 and (not resource_obj.Bus or resource_obj.Bus == 0) then
        log:notice("update port bdf return, Bus check fail, name:%s", resource_obj.NodeId)
        return
    end

    for _, port in pairs(ports) do
        log:notice("update port bdf, name:%s, DevBus=%s, Bus=%s, port_id=%s, Type=%s",
            resource_obj.NodeId, resource_obj.DevBus, resource_obj.Bus, port.PortID, resource_obj.Type)
        -- A nil DevBus is treated like 0 here, matching the guard above, so a
        -- Type 1 object without DevBus falls back to Bus + 1 instead of
        -- producing a query with bus_num = nil.
        local pcie_info = build_port_pcie_info(resource_obj, port)
        -- Only the first result (the success flag) of the verification is kept.
        local ret = verify_vid_did_svid_sdid(pcie_info)
        -- string.format runs under pcall so a bad field cannot abort the loop;
        -- on failure `val` holds the error value, which the log below shows.
        local ok, val = pcall(string.format, '0000:%02x:%02x.%01x', pcie_info.bus_num,
            pcie_info.device_num, pcie_info.function_num)
        log:info("update port bdf, name:%s, val:%s", resource_obj.NodeId, val)
        -- NOTE(review): ports whose BDF is already populated are not counted,
        -- so they contribute to the failure log below -- confirm this is intended.
        if ret and ok and port.BDF == '' then
            port.BDF = val
            count = count + 1
        end
    end
    if count ~= port_num then
        log:error('update port(name:%s) bdf failed', resource_obj.NodeId)
    end
end

-- Export the module table so that require() hands the update_* functions to callers.
return card_resource_update_bdf