-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at:
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local open_db = require 'pcie_device.db'

-- Table that holds the test_* methods defined below. It is deliberately
-- assigned as a global (no `local`) — presumably so the luaunit runner can
-- find it by name; confirm against the runner configuration.
TestPcieFunction = {}

-- Opens a fresh in-memory database for a test.
-- The optional 'pcie_device.datas' module is loaded under pcall; when loading
-- fails the component is treated as not needing any datas configuration and
-- the database is opened without it.
-- Returns the single value produced by open_db.
local function construct_db()
    local loaded, datas_module = pcall(require, 'pcie_device.datas')
    local seed = nil
    if loaded then
        seed = datas_module
    end
    -- Parentheses keep the result truncated to exactly one value.
    return (open_db(':memory:', seed))
end

-- Interface name used in test_set_prop both as the key of the interface-scoped
-- property table inside `mds` and as the optional interface argument passed to
-- get_prop / set_prop.
local PCIEFUNCTION_INTERFACE = 'bmc.kepler.Systems.PCIeDevice.PCIeFunction'

-- Checks how get_prop / set_prop resolve a property name against a hand-built
-- `mds` table that contains top-level ("local") properties and a nested table
-- of interface-scoped properties keyed by PCIEFUNCTION_INTERFACE.
-- NOTE(review): `mds` is assigned on the table returned by require and is not
-- reset at the end of this test, so it stays in place for whatever runs next.
function TestPcieFunction:test_set_prop()
    local c_pcie_function = require 'device.class.pcie_function'
    c_pcie_function.mds = {
        LocalProp = 1,
        SameProp = 1,
        [PCIEFUNCTION_INTERFACE] = {
            InfProp = 1,
            SameProp = 2
        },
        property_changed = {
            on = function()
            end
        }
    }

    -- Reads: without an interface argument the local value wins; with one,
    -- the interface-scoped value is returned.
    lu.assertEquals(c_pcie_function:get_prop('SameProp'), 1)
    lu.assertEquals(c_pcie_function:get_prop('SameProp', PCIEFUNCTION_INTERFACE), 2)
    lu.assertEquals(c_pcie_function:get_prop('LocalProp'), 1)
    lu.assertEquals(c_pcie_function:get_prop('InfProp', PCIEFUNCTION_INTERFACE), 1)
    -- If the property is absent from the local properties, it is looked up in
    -- the default interface properties.
    lu.assertEquals(c_pcie_function:get_prop('InfProp'), 1)

    -- Writes to a name that exists in both places must only touch the
    -- addressed copy.
    c_pcie_function:set_prop('SameProp', 10)
    lu.assertEquals(c_pcie_function:get_prop('SameProp'), 10)
    lu.assertEquals(c_pcie_function:get_prop('SameProp', PCIEFUNCTION_INTERFACE), 2)
    c_pcie_function:set_prop('SameProp', 11, PCIEFUNCTION_INTERFACE)
    lu.assertEquals(c_pcie_function:get_prop('SameProp'), 10)
    lu.assertEquals(c_pcie_function:get_prop('SameProp', PCIEFUNCTION_INTERFACE), 11)

    c_pcie_function:set_prop('LocalProp', 12)
    lu.assertEquals(c_pcie_function:get_prop('LocalProp'), 12)

    -- The interface name belongs to get_prop; previously it was passed to
    -- lu.assertEquals as a third argument, so the interface-scoped read was
    -- never actually checked.
    c_pcie_function:set_prop('InfProp', 13, PCIEFUNCTION_INTERFACE)
    lu.assertEquals(c_pcie_function:get_prop('InfProp', PCIEFUNCTION_INTERFACE), 13)

    -- If the property is absent from the local properties, the write is
    -- expected to go to the default interface properties, so both the plain
    -- and the interface-scoped read must observe the new value.
    c_pcie_function:set_prop('InfProp', 14)
    lu.assertEquals(c_pcie_function:get_prop('InfProp'), 14)
    lu.assertEquals(c_pcie_function:get_prop('InfProp', PCIEFUNCTION_INTERFACE), 14)
end

-- Exercises ctor / update_info / uptree / delete of the pcie_function class
-- with the `mc.*` framework modules replaced by stubs registered in
-- package.preload.
-- NOTE(review): the package.preload stubs installed here are not removed when
-- the test ends, so later requires of these module names still resolve to the
-- stubs — confirm whether other tests depend on that before changing it.
function TestPcieFunction:test_uptree()
    -- Evicts every cached `mc.*` / `pcie_device.*` module and the class under
    -- test from package.loaded so the next require runs the loader again.
    -- Clearing existing entries while traversing with pairs is permitted.
    local function unload_cached_modules()
        for name, _ in pairs(package.loaded) do
            if name:match("^mc%.") or
               name:match("^pcie_device%.") or
               name == "device.class.pcie_function" then
                package.loaded[name] = nil
            end
        end
    end

    package.preload["mc.mdb.mdb_service"] = function()
        return {
            get_path = function()
            end
        }
    end

    -- get_object always reports the same Bus/Device/Function triple.
    package.preload["mc.mdb"] = function()
        return {
            get_object = function()
                return {
                    Bus = 0x60,
                    Device = 0x0,
                    Function = 0x0
                }
            end
        }
    end

    -- create_object hands an empty table to the init callback and returns it.
    package.preload["mc.mdb.object_manage"] = function()
        return {
            create_object = function(_class_name, _path, _name, func)
                local mds = {}
                func(mds)
                return mds
            end
        }
    end

    -- The class manager stub is a factory whose result only offers a no-op
    -- remove.
    package.preload["mc.class_mgnt"] = function()
        return function()
            return {
                remove = function()
                end
            }
        end
    end

    -- Reload the class so that it is built on top of the stubs above.
    unload_cached_modules()
    local c_pcie_function = require 'device.class.pcie_function'

    local record = {
        SegmentNumber = 0,
        LogicProcessorId = 0,
        DeviceType = "PCIeCard",
        BusNumber = 0x40,
        DeviceNumber = 0,
        FunctionNumber = 0,
        SlotId = 2,
        save = function()
        end,
        delete = function()
        end
    }
    c_pcie_function:ctor({}, record, construct_db())
    c_pcie_function.mds = {
        register_mdb_objects = function()
        end,
        unregister_mdb_objects = function()
        end,
        remove = function()
        end
    }
    local card_info = {
        SlotId = 1,
        SegmentNumber = 0,
        LogicProcessorId = 1,
        BusNumber = 0x1c,
        DeviceNumber = 0,
        FunctionNumber = 0
    }
    c_pcie_function:update_info("PCIeCard", 1, card_info)
    local pcie_device = {
        mds = {
            DeviceType = 8,
            FunctionProtocol = "PCIe",
            FunctionType = "Physical",
            BaseClassCode = 0,
            SubClassCode = 0,
            ProgrammingInterface = 0
        },
        get_prop = function()
            return "ObjectName"
        end
    }
    local vdss = { 0x15e1, 0x0007, 0x15e1, 0x1154 }

    -- Replace the ORM constructor with a no-op for the duration of the test;
    -- the original is kept so it can be put back afterwards.
    local c_object = require 'mc.orm.object'
    local mdb_pcie_function = c_object('PCIeFunction')
    local original_new = mdb_pcie_function.new
    mdb_pcie_function.new = function()
    end

    c_pcie_function:uptree(pcie_device, vdss)
    lu.assertEquals(c_pcie_function.mds.FunctionProtocol, "PCIe")
    lu.assertEquals(c_pcie_function.mds.FunctionType, "Physical")

    -- uptree must work again after the object has been deleted.
    c_pcie_function:delete()
    c_pcie_function:uptree(pcie_device, vdss)
    lu.assertEquals(c_pcie_function.mds.ObjectName, "Function_8_2_64_0_0")

    -- Drop the modules loaded on top of the stubs and restore the constructor.
    unload_cached_modules()
    mdb_pcie_function.new = original_new
end

-- Covers update_ras_record / delete_ras_record:
--   1. with no matching database row, none of the RAS counters appear in mds;
--   2. after a row is saved, update_ras_record copies its counters into mds;
--   3. after delete_ras_record, the row can no longer be selected.
function TestPcieFunction:test_ras_record()
    local c_pcie_function = require 'device.class.pcie_function'

    -- Counter names paired with the value stored for them in step 2.
    local counters = {
        { 'FatalErrorCount', 1 },
        { 'NonFatalErrorCount', 2 },
        { 'BadDLLPCount', 3 },
        { 'BadTLPCount', 4 },
        { 'UnsupportedRequestCount', 5 },
        { 'CorrectableErrorOverfrequencyCount', 6 },
    }

    local obj = c_pcie_function.new({}, {
        DeviceType = 'PCIeCard',
        BusNumber = 1,
        DeviceNumber = 2,
        FunctionNumber = 3,
        SlotId = 4
    }, construct_db())
    obj.mds = {}

    -- Step 1: nothing stored yet, so every counter stays unset.
    obj:update_ras_record()
    for _, counter in ipairs(counters) do
        lu.assertEquals(obj.mds[counter[1]], nil)
    end

    -- Step 2: persist a row for this function and read it back into mds.
    local stored_row = obj.db.PCIeFunction({
        DeviceType = 'PCIeCard',
        SlotId = 4,
        BDF = '1:2.3',
        FatalErrorCount = 1,
        NonFatalErrorCount = 2,
        BadDLLPCount = 3,
        BadTLPCount = 4,
        UnsupportedRequestCount = 5,
        CorrectableErrorOverfrequencyCount = 6
    })
    stored_row:save()
    obj:update_ras_record()
    for _, counter in ipairs(counters) do
        lu.assertEquals(obj.mds[counter[1]], counter[2])
    end

    -- Step 3: the row must be gone once the record has been deleted.
    obj:delete_ras_record()
    local db = obj.db
    local tbl = db.PCIeFunction
    local query = db:select(tbl):where(
        tbl.DeviceType:eq('PCIeCard'),
        tbl.SlotId:eq(4),
        tbl.BDF:eq('1:2.3')
    )
    local remaining = query:first()
    lu.assertEquals(remaining, nil)
end