-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local log = require 'mc.logging'
local c_object_manage = require 'mc.orm.object_manage'
local pmu_cmd = require 'imu'
local skynet = require 'skynet'

-- Module table; the public update_* functions below are attached to it and it is returned at end of file.
local card_update_identifier = {}

-- Read 4 bytes at `address` for the PCIe function described by `pcie_info`
-- through the PMU and decode them as two 16-bit little-endian values.
--
-- pcie_info: table providing system_id, is_local, cpu_id, bus_num, device_num
--            and function_num; they are copied verbatim into the request.
-- address:   register address passed to the PMU request.
--
-- Returns low_word, high_word (bytes 1-2 and bytes 3-4 respectively).
-- Returns nothing when the PMU yields no data, fewer than 4 entries, or when
-- both words read 0xffff (presumably the all-ones pattern of an unpopulated
-- function; confirm against the PMU contract).
-- Errors raised by the PMU helper are not caught here; callers wrap in pcall.
local function read_pcie_id_pair(pcie_info, address)
    local payload = {
        system_id = pcie_info.system_id,
        is_local = pcie_info.is_local,
        cpu_id = pcie_info.cpu_id,
        bus_num = pcie_info.bus_num,
        device_num = pcie_info.device_num,
        function_num = pcie_info.function_num,
        address = address,
        read_length = 4
    }
    local bus = c_object_manage.get_instance().bus
    local info = pmu_cmd.get_info_from_pmu(bus, payload)
    if info == nil or #info < 4 then
        return
    end
    -- Each entry is masked to one byte before being combined, low byte first.
    local low_word = ((info[2] & 0xff) << 8) + (info[1] & 0xff)
    local high_word = ((info[4] & 0xff) << 8) + (info[3] & 0xff)
    if low_word == 0xffff and high_word == 0xffff then
        return
    end
    return low_word, high_word
end

-- Read the vendor ID / device ID pair of a PCIe function.
-- Returns vendor_id, device_id, or nothing when the pair cannot be read
-- (see read_pcie_id_pair for the exact conditions).
local function get_pcie_vid_did_info(pcie_info)
    return read_pcie_id_pair(pcie_info, pmu_cmd.pci_info_address_list.VID_DID_INFO_ADDR)
end

-- Read the subsystem vendor ID / subsystem device ID pair of a PCIe function.
-- Returns sub_vendor_id, sub_device_id, or nothing when the pair cannot be
-- read (see read_pcie_id_pair for the exact conditions).
local function get_pcie_sub_vid_did_info(pcie_info)
    return read_pcie_id_pair(pcie_info, pmu_cmd.pci_info_address_list.SUBVID_SUBDID_INFO_ADDR)
end

-- Read the four PCIe identifiers (vendor, device, subsystem vendor, subsystem
-- device) of the function described by `pcie_info`.
--
-- Returns true, vendor_id, device_id, sub_vendor_id, sub_device_id on success.
-- Returns false when either read raises or yields no value.
local function parse_vid_did_svid_sdid(pcie_info)
    local ok, vendor_id, device_id = pcall(get_pcie_vid_did_info, pcie_info)
    if not ok then
        -- On failure pcall returns the error value in the second slot. Log it
        -- instead of dropping it, at the same level as the success trace
        -- below so periodic callers do not get noisier than they already are.
        log:info('get vendor_id/device_id failed: %s', tostring(vendor_id))
        return false
    end
    if not vendor_id then
        -- The reader returned nothing; not an error, so stay silent.
        return false
    end

    local ret, sub_vendor_id, sub_device_id = pcall(get_pcie_sub_vid_did_info, pcie_info)
    if not ret then
        log:info('get sub_vendor_id/sub_device_id failed: %s', tostring(sub_vendor_id))
        return false
    end
    if not sub_vendor_id then
        return false
    end

    log:info('vendor_id: %s, device_id: %s, sub_vendor_id:%s, sub_device_id: %s', vendor_id,
        device_id, sub_vendor_id, sub_device_id)
    return true, vendor_id, device_id, sub_vendor_id, sub_device_id
end

-- Check whether the four PCIe identifiers of `pcie_info` can be read.
-- Verification only needs the first return value of the parser, so exactly
-- that one flag is returned; the identifiers themselves are dropped here
-- rather than being forwarded through a tail call.
local function verify_vid_did_svid_sdid(pcie_info)
    local readable = parse_vid_did_svid_sdid(pcie_info)
    return readable
end

-- Assign a BDF string to every port in `ports`, addressing each port at
-- function number DevFunction + PortID on the card's DevBus/DevDevice.
--
-- ports:   table of port objects; PortID is read and BDF is written.
-- orm_obj: card object; NodeId, DevBus, DevDevice, DevFunction, SocketId,
--          NetworkPortCount and tasks:sleep_ms are used.
--
-- Blocks until DevBus becomes non-zero, sleeping 1000 ms per attempt for at
-- most 121 attempts, and gives up with an error log if it never does.
-- A port's BDF is written only when its identifiers could be read.
function card_update_identifier.update_1822_port_bdf(ports, orm_obj)
    log:notice("update 1822 port bdf start, name:%s", orm_obj.NodeId)

    -- Sleep before each check (including the first one) and stop as soon as
    -- DevBus holds a usable value.
    for _ = 1, 121 do
        orm_obj.tasks:sleep_ms(1000)
        if orm_obj.DevBus and orm_obj.DevBus ~= 0 then
            break
        end
    end

    if not (orm_obj.DevBus and orm_obj.DevBus ~= 0) then
        log:error('update 1822 port bdf, the value of devbus is invalid')
        return
    end

    local expected_ports = orm_obj.NetworkPortCount
    local updated_ports = 0
    for _, port in pairs(ports) do
        log:notice("update 1822 port bdf, name:%s, DevBus=%s, DevDevice=%s, port_id=%s, DevFunction=%s",
            orm_obj.NodeId, orm_obj.DevBus, orm_obj.DevDevice, port.PortID, orm_obj.DevFunction)

        local target = {
            system_id = 1,
            is_local = false,
            cpu_id = orm_obj.SocketId,
            bus_num = orm_obj.DevBus,
            device_num = orm_obj.DevDevice,
            function_num = orm_obj.DevFunction + port.PortID
        }
        local verified = verify_vid_did_svid_sdid(target)
        -- string.format raises on non-numeric fields, hence the pcall.
        local formatted, bdf = pcall(string.format, '0000:%02x:%02x.%01x',
            target.bus_num, target.device_num, target.function_num)

        if verified and formatted then
            log:notice("update 1822 port bdf, name:%s, port = %s, set port BDF val=%s",
                orm_obj.NodeId, port.PortID, bdf)
            port.BDF = bdf
            updated_ports = updated_ports + 1
        end
    end

    if updated_ports ~= expected_ports then
        log:error('update 1822 port bdf failed, name:%s', orm_obj.NodeId)
    end
end

-- Assign a BDF string to every port in `ports` whose BDF is still empty,
-- using the port's PortID as the PCIe function number.
--
-- ports:   table of port objects; PortID is read, BDF is read and written.
-- orm_obj: card object; NodeId, Type, DevBus, DevDevice, Bus, Device,
--          SocketId and NetworkPortCount are used.
--
-- Addressing:
--   * DevBus usable (non-nil, non-zero): DevBus / DevDevice.
--   * DevBus unusable and Type == 1:     Bus + 1 / Device.
--     NOTE(review): the "+ 1" presumably targets the bus behind the slot's
--     bridge; confirm with the platform owner.
-- Returns early, with a notice log, when no usable bus number exists for the
-- card's Type. Logs an error when fewer than NetworkPortCount ports were
-- written.
-- NOTE(review): ports whose BDF was already non-empty are not counted, so a
-- second call on an already populated card also logs the failure message.
function card_update_identifier.update_port_bdf(ports, orm_obj)
    local port_num = orm_obj.NetworkPortCount
    local count = 0

    if orm_obj.Type ~= 1 and (not orm_obj.DevBus or orm_obj.DevBus == 0) then
        log:notice("update port bdf return, DevBus check fail, name:%s", orm_obj.NodeId)
        return
    end

    if orm_obj.Type == 1 and (not orm_obj.Bus or orm_obj.Bus == 0) then
        log:notice("update port bdf return, Bus check fail, name:%s", orm_obj.NodeId)
        return
    end

    for _, port in pairs(ports) do
        log:notice("update port bdf, name:%s, DevBus=%s, Bus=%s, port_id=%s, Type=%s",
            orm_obj.NodeId, orm_obj.DevBus, orm_obj.Bus, port.PortID, orm_obj.Type)

        -- Use the same "DevBus is unusable" test as the guard above: a nil
        -- DevBus must fall back to Bus exactly like a zero DevBus does,
        -- instead of producing a request with a nil bus number.
        local use_bus_fallback = orm_obj.Type == 1 and (not orm_obj.DevBus or orm_obj.DevBus == 0)
        local bus_num, device_num
        if use_bus_fallback then
            bus_num, device_num = orm_obj.Bus + 1, orm_obj.Device
        else
            bus_num, device_num = orm_obj.DevBus, orm_obj.DevDevice
        end

        local pcie_info = {
            system_id = 1,
            is_local = false,
            cpu_id = orm_obj.SocketId,
            bus_num = bus_num,
            device_num = device_num,
            function_num = port.PortID
        }
        local ret = verify_vid_did_svid_sdid(pcie_info)
        -- string.format raises on non-numeric fields, hence the pcall; on
        -- failure `val` holds the error message and is only logged.
        local ok, val = pcall(string.format, '0000:%02x:%02x.%01x', pcie_info.bus_num,
            pcie_info.device_num, pcie_info.function_num)
        log:info("update port bdf, name:%s, val:%s", orm_obj.NodeId, val)
        if ret and ok and port.BDF == '' then
            port.BDF = val
            count = count + 1
        end
    end

    if count ~= port_num then
        log:error('update port(name:%s) bdf failed', orm_obj.NodeId)
    end
end

-- Read the four PCIe identifiers for `pcie_info` and store them back into the
-- same table as vid, did, svid and sdid.
-- Returns true when the fields were written, false (table untouched) when the
-- identifiers could not be read.
local function get_vid_did_svid_sdid(pcie_info)
    local parsed, vid, did, svid, sdid = parse_vid_did_svid_sdid(pcie_info)
    if parsed then
        pcie_info.vid = vid
        pcie_info.did = did
        pcie_info.svid = svid
        pcie_info.sdid = sdid
        return true
    end
    return false
end

-- Poll the card's PCIe identifiers forever and publish them on `orm_obj` as
-- '0x%04x' strings in VendorID, DeviceID, SubsystemVendorID and
-- SubsystemDeviceID.
--
-- orm_obj: card object; SocketId, DevBus, DevDevice, DevFunction and
--          sleep_ms are used, and the four ID properties are written.
--
-- Never returns once the loop has started; it is meant to run in its own
-- task. Returns immediately when the global STOP_INIT is set.
function card_update_identifier.update_quater_info(orm_obj)
    if STOP_INIT then
        -- Unit tests must not enter this function, otherwise they would be
        -- stuck in the endless loop below.
        return
    end
    -- During startup, wait 2 minutes before updating the NIC identifiers.
    skynet.sleep(12000)
    log:notice('init quater info')
    local pcie_info = {
        system_id = 1,
        is_local = false,
        cpu_id = orm_obj.SocketId,
        did = 0,
        vid = 0,
        sdid = 0,
        svid = 0
    }
    while true do
        local dev_bus = orm_obj.DevBus
        -- Skip the read while DevBus is unusable. nil is treated like 0, the
        -- same way the port BDF functions in this module treat it, so no
        -- request with a nil bus number is ever sent.
        if dev_bus and dev_bus ~= 0 then
            pcie_info.bus_num = dev_bus
            pcie_info.device_num = orm_obj.DevDevice
            pcie_info.function_num = orm_obj.DevFunction
            if get_vid_did_svid_sdid(pcie_info) then
                orm_obj.VendorID = string.format('0x%04x', pcie_info.vid)
                orm_obj.DeviceID = string.format('0x%04x', pcie_info.did)
                orm_obj.SubsystemVendorID = string.format('0x%04x', pcie_info.svid)
                orm_obj.SubsystemDeviceID = string.format('0x%04x', pcie_info.sdid)
            end
        end
        -- Poll once every 10 s.
        orm_obj:sleep_ms(10000)
    end
end

-- Export the module table to require() callers.
return card_update_identifier