-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local c_device_loader = require 'biz_topo.device_loader'
local c_load_info = require 'biz_topo.class.load_info'
local c_device_service = require 'device.device_service'
local cmn = require 'common'

-- Test-case table for biz_topo.device_loader; every test_* method defined below is attached to it.
-- NOTE(review): assigned as a global rather than a local — presumably so the luaunit
-- runner can look the suite up by name; confirm against the runner before changing.
TestDeviceLoader = {}

-- Feeds a fixed 15-byte string to parse_pcie_card_bdf_data and checks the
-- bus_info fields of result entries 1 and 3 (entry 2 is not inspected).
function TestDeviceLoader:test_parse_pcie_card_bdf_data()
    local raw_bdf = '\x00\x00\x01\x01\x01\x01\x00\x00\x00\x00\x00\x01\x02\x02\x02'
    local parsed = c_device_loader.parse_pcie_card_bdf_data(raw_bdf)

    -- PCIeSlot 1
    local slot1 = parsed[1].bus_info
    lu.assertEquals(slot1.segment, 0)
    lu.assertEquals(slot1.socket_id, 0)
    lu.assertEquals(slot1.bus, 1)
    lu.assertEquals(slot1.device, 1)
    lu.assertEquals(slot1.func, 1)

    -- PCIeSlot 3 (looked up only after the slot 1 checks have passed)
    local slot3 = parsed[3].bus_info
    lu.assertEquals(slot3.segment, 0)
    lu.assertEquals(slot3.socket_id, 1)
    lu.assertEquals(slot3.bus, 2)
    lu.assertEquals(slot3.device, 2)
    lu.assertEquals(slot3.func, 2)
end

-- Smoke test for get_pcie_vid_did_info with a stubbed bus object.
-- Nothing is asserted afterwards; the test fails only if the call raises.
function TestDeviceLoader:test_get_pcie_vid_did_info()
    -- Stub bus: call() always returns 0 followed by a fixed 8-byte string.
    local function fake_call()
        return 0, '\x00\x07\xdb\x04\x01\x02\x03\x04'
    end
    local fake_bus = {call = fake_call}

    local bdf = {socket_id = 1, bus = 1, device = 1, func = 1}
    local card = {bus_info = bdf}

    c_device_loader.get_pcie_vid_did_info(1, fake_bus, card)
end

-- Smoke test for load_unload_device using a stub loader whose
-- biz_topo.get_mgmt_connector always returns true, nil and a connector table
-- with Presence = 1 and LoadStatus = 0.
-- Nothing is asserted afterwards; the test fails only if the call raises.
function TestDeviceLoader:test_load_unload_device()
    local mock_device_load = {
        biz_topo = {
            get_mgmt_connector = function(...)
                return true, nil, {Presence = 1, LoadStatus = 0}
            end
        }
    }
    local load_info = {
        slot_id = 256
    }
    -- The last three arguments are deliberately nil. They are spelled out
    -- instead of reading the undeclared global `_`, which would silently pick
    -- up any value another module stored in _G._ (or raise under a strict
    -- globals guard).
    c_device_loader.load_unload_device(mock_device_load, load_info, nil, nil, nil)
end

-- Exercises get_id_auxid_str with three load_info tables and checks the two
-- strings it returns for each.
function TestDeviceLoader:test_get_id_auxid_str()
    local load_info = {
        ['id'] = 1,
        ['aux_id'] = 2,
        ['type'] = 'PCIeCard'
    }
    -- Stub loader: the custom-display hook hands id / aux_id back unchanged.
    local device_loader = {
        handle_pcie_custom_display = function(_, t_id, t_aux_id)
            return t_id, t_aux_id
        end
    }

    -- Case 1: PCIeCard with nil as the second argument -> both strings empty.
    local str_id, str_auxid = c_device_loader.get_id_auxid_str(device_loader, nil, load_info)
    lu.assertEquals(str_id, '')
    lu.assertEquals(str_auxid, '')

    -- Case 2: OCPCard with 'bom' -> ids rendered as 8-digit zero-padded hex.
    load_info = {
        ['id'] = 1,
        ['aux_id'] = 2,
        ['type'] = 'OCPCard'
    }
    str_id, str_auxid = c_device_loader.get_id_auxid_str(device_loader, 'bom', load_info)
    lu.assertEquals(str_id, "00000001")
    lu.assertEquals(str_auxid, "00000002")

    -- Case 3: PCIeCard with id 0xfffffffe -> aux id is expected to come back
    -- as "ffffffff" even though aux_id = 2 was supplied.
    load_info = {
        ['id'] = 0xfffffffe,
        ['aux_id'] = 2,
        ['type'] = 'PCIeCard'
    }
    str_id, str_auxid = c_device_loader.get_id_auxid_str(device_loader, 'bom', load_info)
    lu.assertEquals(str_id, "fffffffe")
    lu.assertEquals(str_auxid, "ffffffff")
end

-- Checks that unload_device_connector runs without raising while
-- c_device_service.get_instance is temporarily replaced by a stub whose
-- clear_flash_checker_event is a no-op.
function TestDeviceLoader:test_unload_device_connector()
    local original_get_instance = c_device_service.get_instance
    c_device_service.get_instance = function()
        return {
            clear_flash_checker_event = function()
            end
        }
    end
    local load_info = {
        slot_id = 256
    }
    local connector_info = {
        Id = "fffffffe",
        AuxId = "ffffffff",
    }
    local ok, err = pcall(function()
        c_device_loader.unload_device_connector(load_info, connector_info, 1)
    end)

    -- Undo the monkey patch BEFORE asserting: a failing assertion raises, and
    -- the stub must not stay installed for the tests that run afterwards.
    c_device_service.get_instance = original_get_instance

    -- The captured error is passed as the failure message so a regression
    -- reports why the call raised rather than just "false ~= true".
    lu.assertEquals(ok, true, tostring(err))
end

-- Checks handle_pcie_custom_display against a stubbed db query chain
-- (select -> where -> all) that always yields `records`:
--   * record enabled  -> returns the record's QuadrupleId and CustomType;
--   * record disabled -> returns the id / aux_id that were passed in.
function TestDeviceLoader:test_handle_pcie_custom_display()
    local records = {
        {
            Enabled = true,
            CustomType = "NPU",
            QuadrupleId = "test"
        }
    }
    local device_loader = {
        db = {
            PCIeCardDisplayCustom = {
                QuadrupleId = {
                    eq = function(id)
                        return id
                    end
                }
            },
            -- The argument is ignored; it is named `_` so the built-in
            -- `table` library is not shadowed inside the stub.
            select = function(_)
                return {
                    where = function()
                        return {
                            all = function()
                                return records
                            end
                        }
                    end
                }
            end
        }
    }
    local id, aux_id = c_device_loader.handle_pcie_custom_display(device_loader, 1, 2)
    lu.assertEquals(id, "test")
    lu.assertEquals(aux_id, "NPU")

    -- `records` is shared with the stub, so this flips what the next query sees.
    records[1].Enabled = false
    -- Re-assign the existing locals rather than declaring a second, shadowing pair.
    id, aux_id = c_device_loader.handle_pcie_custom_display(device_loader, 1, 2)
    lu.assertEquals(id, 1)
    lu.assertEquals(aux_id, 2)
end

-- Runs set_device_present_by_map on a private, freshly loaded copy of
-- biz_topo.device_loader whose collaborators are stubbed, then checks that the
-- stubbed address object received a did of 1 and that two load infos were
-- handed to device_load_info_persist.
function TestDeviceLoader:test_set_device_present_by_map()
    -- Load a fresh copy of the module so the stubs installed below do not touch
    -- the file-level c_device_loader. The cached module is put back straight
    -- away so that later require() calls elsewhere do not receive this
    -- test's mutated copy.
    local module_name = 'biz_topo.device_loader'
    local cached_module = package.loaded[module_name]
    package.loaded[module_name] = nil
    local loader = require(module_name)
    package.loaded[module_name] = cached_module

    -- Replacement for the module-private query_id_from_pmu: writes fixed ids
    -- onto the info table it is given.
    local query_id_from_pmu = function(_sys_id, _bus, info)
        info.vid = 1
        info.did = 1
        info.sub_vid = 1
        info.sub_did = 1
    end

    -- Walk the upvalues of set_device_present_by_map and swap in the stub when
    -- one named "query_id_from_pmu" is found (nothing is patched otherwise).
    local idx = 1
    while true do
        local name = debug.getupvalue(loader.set_device_present_by_map, idx)
        if not name then break end
        if name == "query_id_from_pmu" then          -- target found
            debug.setupvalue(loader.set_device_present_by_map, idx, query_id_from_pmu)
            break
        end
        idx = idx + 1
    end

    local slotid2ssbdf = {
        [1] = {0, 0, 0x1c, 0, 0},
        [2] = {0, 0, 0x40, 0, 0}
    }

    -- Stub address object: records what update_multihost_presence and
    -- update_vid_did are called with; get_prop always returns 1.
    local test_addr = {
        position = "01010101",
        update_multihost_presence = function(self, addr_multi_presence)
            self.addr_multi_presence = addr_multi_presence
        end,
        update_vid_did = function(self, vid, did)
            self.vid = vid
            self.did = did
        end,
        update_device_bdf = function(...) end,
        get_prop = function(...) return 1 end
    }

    loader.biz_topo = {
        get_pcie_addr_info = function(_self, _a, _b)
            return true, {
                addr = test_addr
            }
        end
    }

    -- Capture the load infos that would otherwise be persisted.
    local load_infos
    function loader:device_load_info_persist(_sys_id, _device_type, new_pcie_load_info)
        load_infos = new_pcie_load_info
    end

    function loader:task_load_unload_device(...) end

    loader:set_device_present_by_map(1, "PCIeCard", slotid2ssbdf)
    lu.assertEquals(test_addr.did, 1)
    lu.assertEquals(#load_infos, 2)
end

-- Drives process_pci_info through five configurations of a stub loader and
-- connector and checks that none of the calls raises.
function TestDeviceLoader:test_process_pci_info()
    -- Builds a get_prop stub. All properties are fixed except DevBus, which is
    -- the only value that differs between the scenarios below. Unknown property
    -- names yield no return value.
    local function make_get_prop(dev_bus)
        local props = {
            ComponentType = 83,
            SlotID = 1,
            SocketID = 1,
            DevBus = dev_bus,
            DevDevice = 1,
            DevFunction = 1,
            Segment = 1,
            MultihostPresence = 1
        }
        return function(_, prop)
            local value = props[prop]
            if value ~= nil then
                return value
            end
        end
    end

    local connector = {
        ref_pcie_addr_info = {
            get_prop = make_get_prop(1),
            update_device_bdf = function()
            end,
            update_multihost_presence = function()
            end,
            update_vid_did = function()
            end
        }
    }
    -- Starts with no persisted entry under 'OCPCard'.
    local persistent_load_info = {
        ['OCPCard'] = {}
    }
    local device_loader_obj = {
        hot_plug_filter = {},
        task_load_unload_device = function()
        end,
        device_load_info_persist_ocp = function()
        end,
        persistent_load_info = persistent_load_info
    }

    local ok, err = pcall(function()
        -- 1) nothing persisted, empty hot-plug filter
        c_device_loader.process_pci_info(device_loader_obj, connector, false)

        -- 2) an entry is persisted under OCPCard[1]
        device_loader_obj.persistent_load_info = {
            ['OCPCard'] = {
                [1] = true
            }
        }
        c_device_loader.process_pci_info(device_loader_obj, connector, false)

        -- 3) hot-plug filter entry '1_83' set to true
        device_loader_obj.hot_plug_filter = {
            ['1_83'] = true
        }
        c_device_loader.process_pci_info(device_loader_obj, connector, false)

        -- 4) hot-plug filter entry '1_83' set to false, third argument true
        device_loader_obj.hot_plug_filter = {
            ['1_83'] = false
        }
        c_device_loader.process_pci_info(device_loader_obj, connector, true)

        -- 5) same as 4) but DevBus now reads 0xff
        connector.ref_pcie_addr_info.get_prop = make_get_prop(0xff)
        c_device_loader.process_pci_info(device_loader_obj, connector, true)
    end)

    -- Pass the captured error as the failure message so a regression says
    -- which call raised and why.
    lu.assertEquals(ok, true, tostring(err))
end

-- Calls ocp_load_from_tc five times against a stub loader while the single
-- PCIeBizConnector it exposes is progressively filled in, and checks that
-- none of the calls raises.
function TestDeviceLoader:test_ocp_load_from_tc()
    -- Fixed property values served by the connector's get_prop stub; unknown
    -- property names yield no return value.
    local props = {
        ComponentType = 83,
        SlotID = 1,
        SocketID = 1,
        DevBus = 1,
        DevDevice = 1,
        DevFunction = 1,
        Segment = 1,
        MultihostPresence = 1
    }

    -- `connector` is re-assigned below; get_objs reads the variable on every
    -- call, so it always returns the current connector.
    local connector = {}
    local device_loader_obj = {
        biz_topo = {
            get_objs = function(_, prop)
                if prop == 'PCIeBizConnector' then
                    return {
                        [1] = connector
                    }
                end
            end
        },
        process_pci_info = function()
        end
    }
    local ok, err = pcall(function()
        -- 1) connector is an empty table
        c_device_loader.ocp_load_from_tc(device_loader_obj)

        -- 2) connector has address info but no management connectors
        connector = {
            ref_pcie_addr_info = {
                get_prop = function(_, prop)
                    local value = props[prop]
                    if value ~= nil then
                        return value
                    end
                end,
                update_device_bdf = function()
                end,
                update_multihost_presence = function()
                end,
                update_vid_did = function()
                end
            }
        }
        c_device_loader.ocp_load_from_tc(device_loader_obj)

        -- 3) ref_mgmt_connector_tianchi only, Presence = 0
        connector.ref_mgmt_connector_tianchi = {
            Presence = 0
        }
        c_device_loader.ocp_load_from_tc(device_loader_obj)

        -- 4) ref_mgmt_connector_tianchi Presence = 0, ref_mgmt_connector Presence = 1
        connector.ref_mgmt_connector = {
            Presence = 1
        }
        c_device_loader.ocp_load_from_tc(device_loader_obj)

        -- 5) ref_mgmt_connector_tianchi Presence = 1, ref_mgmt_connector Presence = 0
        connector.ref_mgmt_connector_tianchi = {
            Presence = 1
        }
        connector.ref_mgmt_connector = {
            Presence = 0
        }
        c_device_loader.ocp_load_from_tc(device_loader_obj)
    end)

    -- Pass the captured error as the failure message so a regression says
    -- which call raised and why.
    lu.assertEquals(ok, true, tostring(err))
end

-- Checks that device_load_info_persist_ocp, given info with aux_id = 2,
-- updates the AuxID of the existing persisted OCPCard[1] record from 1 to 2
-- without raising.
function TestDeviceLoader:test_device_load_info_persist_ocp()
    -- Persisted record with a no-op save(), so nothing is written anywhere.
    local persistent_load_info = {
        ['OCPCard'] = {
            [1] = {
                SlotID = 1,
                ID = 1,
                AuxID = 1,
                MultihostPresence = 1,
                save = function()
                end
            }
        }
    }
    local device_loader_obj = {
        persistent_load_info = persistent_load_info
    }
    local info = {
        id = 1,
        aux_id = 2
    }
    local ok, err = pcall(function()
        c_device_loader.device_load_info_persist_ocp(device_loader_obj, 1, 'OCPCard', 1, info)
    end)
    -- Pass the captured error as the failure message so a regression reports
    -- why the call raised.
    lu.assertEquals(ok, true, tostring(err))
    lu.assertEquals(persistent_load_info['OCPCard'][1].AuxID, 2)
end