-- Copyright (c) 2025 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local defs = require 'independent_vrd.ind_vrd_defs'

-- Test suite: global table that holds the test_* cases defined below
TestSd500x = {}

-- Mock helper: snapshots the relevant package.loaded entries and installs
-- deterministic stubs for the driver's dependencies.
--
-- Stubs installed:
--   * 'mc.crc8'                        -> byte sum of the input modulo 256
--   * 'independent_vrd.tool.parse_hex' -> always succeeds (ret = 0) and yields
--                                         a zero-filled buffer of expected_len
-- The cached copy of the module under test is evicted as well, so that a copy
-- loaded before the stubs were in place is never reused by a later require().
--
-- @return table  snapshot with fields skynet, mc_crc8, parse_hex and sd500x,
--                meant to be handed back to restore_mocks()
local function setup_mocks()
    local loaded = package.loaded
    local sd500x_key = 'independent_vrd.chip.sd500x_pmbus'

    -- Save the original modules (taken before anything is replaced)
    local originals = {
        skynet = loaded['skynet'],
        mc_crc8 = loaded['mc.crc8'],
        parse_hex = loaded['independent_vrd.tool.parse_hex'],
        sd500x = loaded[sd500x_key]
    }

    -- Install the mocks
    loaded['skynet'] = require 'skynet'
    loaded['mc.crc8'] = function(data)
        local sum = 0
        for i = 1, #data do
            sum = sum + string.byte(data, i)
        end
        return sum % 256
    end
    loaded['independent_vrd.tool.parse_hex'] = function(path, expected_len)
        return {
            ret = 0,
            real_len = expected_len,
            buffer = string.rep('\x00', expected_len)
        }
    end

    -- Force the next require() of the driver to load it afresh; without this
    -- a copy cached before the stubs were installed would keep being returned
    loaded[sd500x_key] = nil

    return originals
end

-- Puts back the package.loaded entries captured by setup_mocks().
-- The entry of the module under test is always cleared rather than restored,
-- so the next require() loads the driver again.
--
-- @param originals table  snapshot returned by setup_mocks()
local function restore_mocks(originals)
    package.loaded['skynet'] = originals.skynet
    package.loaded['mc.crc8'] = originals.mc_crc8
    package.loaded['independent_vrd.tool.parse_hex'] = originals.parse_hex
    package.loaded['independent_vrd.chip.sd500x_pmbus'] = nil
end

-- luaunit fixture hooks acting as a safety net.
-- Every test case calls restore_mocks() as its very last statement, which is
-- skipped when an assertion fails midway and would leave the stubs installed
-- for the following tests. luaunit runs tearDown after each test regardless
-- of its outcome, so the snapshot taken in setUp is always put back.
-- The per-test setup_mocks()/restore_mocks() pairs nest inside these hooks:
-- each restore puts back exactly what its paired setup captured.
function TestSd500x:setUp()
    self.fixture_originals = setup_mocks()
end

function TestSd500x:tearDown()
    local saved = self.fixture_originals
    self.fixture_originals = nil
    -- setUp may have failed before a snapshot existed
    if saved then
        restore_mocks(saved)
    end
end

-- Builds a stub chip object exposing Read, Write, BatchWrite and PluginRequest.
--
-- @param addr number|nil  bus address used for the checksum head (default 0x50)
-- @return table  the stub chip
local function create_mock_chip(addr)
    if not addr then
        addr = 0x50
    end

    local chip = {}

    -- Answers with (len - 1) bytes of 0x01 followed by one checksum byte; the
    -- checksum is the currently installed 'mc.crc8' applied to
    -- addr, cmd, addr|1 and the payload
    function chip.Read(self, ctx, cmd, len)
        local checksum_fn = package.loaded['mc.crc8']
        local header = string.char(addr, cmd, addr | 1)
        local payload = ('\x01'):rep(len - 1)
        return payload .. string.char(checksum_fn(header .. payload))
    end

    -- Single writes always succeed
    function chip.Write(self, ctx, cmd, data)
        return true
    end

    -- Batched writes always succeed
    function chip.BatchWrite(self, ctx, batch)
        return true
    end

    -- Produces one 17-byte reply of 0x01 for every entry of the unpacked request list
    function chip.PluginRequest(self, ctx, plugin, method, args)
        local skynet = require 'skynet'
        local replies = {}
        local requests = skynet.unpack(args)
        for idx = 1, #requests do
            replies[idx] = string.rep('\x01', 17)
        end
        return skynet.packstring(true, replies)
    end

    return chip
end

-- Creates an sd500x driver instance wired to stub chips.
--
-- @return instance  first value returned by sd500x.new
-- @return sd500x    the driver module that was required
local function create_sd500x_instance()
    local sd500x = require 'independent_vrd.chip.sd500x_pmbus'
    local main_addr, comp_addr = 0x50, 0x51

    local instance = sd500x.new({
        object = {
            VrdId = 1,
            UID = 'test_uid',
            Address = main_addr,
            CompAddress = comp_addr,
            RefChip = create_mock_chip(main_addr),
            CompRefChip = create_mock_chip(comp_addr),
            ValidateReg = 0,
            UpgradeFileName = {
                FirmwareFileName = 'test_cfg.hex',
                BootFileName = 'test_boot.hex'
            }
        }
    }, 1)

    return instance, sd500x
end

-- Test: get_version
function TestSd500x:test_get_version()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Stub read_data so that it reports OK with a single byte of value 1
    function vrd.read_data()
        return defs.RET.OK, string.char(1)
    end

    local ver = vrd:get_version()
    lu.assertNotNil(ver)
    lu.assertEquals(type(ver), 'number')
    lu.assertEquals(ver, 1)

    restore_mocks(saved)
end

-- Test: check_chip_accessible
function TestSd500x:test_check_chip_accessible()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- With the default stub chips the check is expected to succeed
    local accessible = vrd:check_chip_accessible()
    lu.assertTrue(accessible)

    restore_mocks(saved)
end

-- Test: check_chip_accessible - failure scenario
function TestSd500x:test_check_chip_accessible_no_refchip()
    local saved = setup_mocks()
    local sd500x = require 'independent_vrd.chip.sd500x_pmbus'

    -- RefChip is deliberately left out of the object description
    local vrd = sd500x.new({
        object = {
            VrdId = 1,
            UID = 'test_uid',
            Address = 0x50,
            CompRefChip = create_mock_chip(0x51)
        }
    }, 1)

    local accessible = vrd:check_chip_accessible()
    lu.assertFalse(accessible)

    restore_mocks(saved)
end

-- Test: switch_to_boot_rom_mode
function TestSd500x:test_switch_to_boot_rom_mode()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- With the default stub chips the switch is expected to report OK
    local status = vrd:switch_to_boot_rom_mode()
    lu.assertEquals(status, defs.RET.OK)

    restore_mocks(saved)
end

-- Test: check_boot_rom_version - success path and version number format
function TestSd500x:test_check_boot_rom_version()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Makes the companion chip answer every read with `payload` followed by a
    -- checksum byte computed by the installed 'mc.crc8' over the
    -- addr/cmd/addr|1 head plus the payload
    local function answer_with(payload)
        vrd.obj.CompRefChip.Read = function(_, _, cmd)
            local checksum_fn = package.loaded['mc.crc8']
            local addr = vrd.obj.CompAddress
            local header = string.char(addr, cmd, addr | 1)
            return payload .. string.char(checksum_fn(header .. payload))
        end
    end

    -- Version format: data[2]==0 && data[3]==1 && data[4]==0 && data[5]>=1 && data[5]<=9
    local good_payload = '\x04\x00\x01\x00\x05' -- version 0.1.0.5
    answer_with(good_payload)

    local status = vrd:check_boot_rom_version()
    lu.assertEquals(status, defs.RET.OK)
    -- Check the version format: 0.1.0.5 (byte 2 = 0, byte 3 = 1, byte 4 = 0, byte 5 = 5)
    lu.assertEquals(good_payload:byte(2), 0, 'Version byte 2 should be 0')
    lu.assertEquals(good_payload:byte(3), 1, 'Version byte 3 should be 1')
    lu.assertEquals(good_payload:byte(4), 0, 'Version byte 4 should be 0')
    local patch_byte = good_payload:byte(5)
    lu.assertTrue(patch_byte >= 1 and patch_byte <= 9, 'Version byte 5 should be 1-9')

    -- An invalid version number must be rejected
    answer_with('\x04\xFF\xFF\xFF\xFF')
    local bad_status = vrd:check_boot_rom_version()
    lu.assertEquals(bad_status, defs.RET.ERR, 'Invalid version should return ERR')

    restore_mocks(saved)
end

-- Test: write_data
function TestSd500x:test_write_data()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Record what reaches the chip's Write
    local seen = { called = false, cmd = nil, data = nil }
    function vrd.obj.RefChip.Write(_, _, cmd, data)
        seen.called = true
        seen.cmd = cmd
        seen.data = data
        return true
    end

    local payload = '\x01\x02\x03\x04\x05'
    local status = vrd:write_data(vrd.obj.RefChip, 0xFD, payload, vrd.obj.Address)
    lu.assertEquals(status, defs.RET.OK)
    lu.assertTrue(seen.called, 'Write should be called')
    lu.assertEquals(seen.cmd, 0xFD)
    -- The written data must hold the original data plus a CRC (5 data bytes + 1 CRC byte = 6 bytes)
    lu.assertEquals(#seen.data, 6, 'Write data should include CRC')
    lu.assertEquals(seen.data:sub(1, 5), payload, 'Write data should contain original data')

    restore_mocks(saved)
end

-- Test: read_data
function TestSd500x:test_read_data()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    local want = ('\x01'):rep(17)

    -- Chip answers with (len - 1) bytes of 0x01 plus one checksum byte
    function vrd.obj.RefChip.Read(_, _, cmd, len)
        local checksum_fn = package.loaded['mc.crc8']
        local addr = vrd.obj.Address
        local header = string.char(addr, cmd, addr | 1)
        local payload = ('\x01'):rep(len - 1)
        return payload .. string.char(checksum_fn(header .. payload))
    end

    local status, got = vrd:read_data(vrd.obj.RefChip, 0xF9, 17, vrd.obj.Address)
    lu.assertEquals(status, defs.RET.OK)
    lu.assertNotNil(got)
    lu.assertEquals(#got, 17)
    lu.assertEquals(got, want)

    restore_mocks(saved)
end

-- Test: batch_write
function TestSd500x:test_batch_write()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Record what reaches the chip's BatchWrite
    local seen = { called = false, batch = nil }
    function vrd.obj.RefChip.BatchWrite(_, _, batch)
        seen.called = true
        seen.batch = batch
        return true
    end

    local zero_block = ('\x00'):rep(21)
    local status = vrd:batch_write(vrd.obj.RefChip, {
        { 0xF4, zero_block },
        { 0xF4, zero_block }
    })
    lu.assertEquals(status, defs.RET.OK)
    lu.assertTrue(seen.called, 'BatchWrite should be called')
    lu.assertNotNil(seen.batch)
    lu.assertEquals(#seen.batch, 2, 'Should have 2 write commands in batch')

    restore_mocks(saved)
end

-- Test: erase_cfg_blocks
function TestSd500x:test_erase_cfg_blocks()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Collect every batch handed to batch_write
    local batches = {}
    function vrd.batch_write(_, _, batch)
        batches[#batches + 1] = batch
        return defs.RET.OK
    end

    local status = vrd:erase_cfg_blocks()
    lu.assertEquals(status, defs.RET.OK)
    -- Erase commands: 4 blocks of the main partition + 4 blocks of the backup area = 8 commands
    lu.assertEquals(#batches, 1, 'Should have one batch write call')
    lu.assertEquals(#batches[1], 8, 'Should erase 8 blocks (4 main + 4 backup)')

    restore_mocks(saved)
end

-- Test: read_cfg
function TestSd500x:test_read_cfg()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    local block = ('\x01'):rep(16) -- 16 data bytes per block

    function vrd.write_data()
        return defs.RET.OK
    end
    function vrd.read_data()
        -- 17 bytes are returned: 1 head byte + 16 data bytes
        return defs.RET.OK, '\x01' .. string.rep('\x01', 16)
    end

    local status, got = vrd:read_cfg({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 32 -- 2 blocks
    })
    lu.assertEquals(status, defs.RET.OK)
    lu.assertNotNil(got)
    lu.assertEquals(#got, 32)
    -- Content check: two 16-byte blocks back to back
    lu.assertEquals(got, block .. block)

    restore_mocks(saved)
end

-- Test: batch_read_cfg
function TestSd500x:test_batch_read_cfg()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    local status, got = vrd:batch_read_cfg({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 256 -- 16 blocks
    })
    lu.assertEquals(status, defs.RET.OK)
    lu.assertNotNil(got)
    lu.assertEquals(#got, 256)
    -- Format check: every block must be 16 bytes long
    for blk = 1, 16 do
        local last = blk * 16
        local piece = got:sub(last - 15, last)
        lu.assertEquals(#piece, 16, 'Block ' .. blk .. ' should be 16 bytes')
    end

    restore_mocks(saved)
end

-- Test: verify_data
function TestSd500x:test_verify_data()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    local image = ('\x01'):rep(256)

    -- The read-back returns exactly the data that is being verified
    function vrd.batch_read_cfg()
        return defs.RET.OK, image
    end

    local status = vrd:verify_data({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 256,
        data = image
    })
    lu.assertEquals(status, defs.RET.OK)

    restore_mocks(saved)
end

-- Test: verify_data - data mismatch
function TestSd500x:test_verify_data_mismatch()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- The read-back differs from the data that is being verified
    local readback = ('\x01'):rep(256)
    function vrd.batch_read_cfg()
        return defs.RET.OK, readback
    end

    local wanted = ('\x02'):rep(256)
    local status = vrd:verify_data({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 256,
        data = wanted
    })
    lu.assertEquals(status, defs.RET.ERR)
    -- Confirm the data really differs: already the first byte is not the same
    lu.assertNotEquals(readback:byte(1), wanted:byte(1),
        'Data should be different at first byte')

    restore_mocks(saved)
end

-- Test: write_cfg
function TestSd500x:test_write_cfg()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Collect every batch handed to batch_write
    local batches = {}
    function vrd.batch_write(_, _, batch)
        batches[#batches + 1] = batch
        return defs.RET.OK
    end

    local status = vrd:write_cfg({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 32,
        head_len = 20,
        data = ('\x00'):rep(32)
    })
    lu.assertEquals(status, defs.RET.OK)
    -- The batch write must happen once with the right entry count (32 bytes / 16 bytes per block = 2 blocks)
    lu.assertEquals(#batches, 1)
    lu.assertEquals(#batches[1], 2, 'Should write 2 blocks for 32 bytes')

    restore_mocks(saved)
end

-- Test: write_file_data
function TestSd500x:test_write_file_data()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Record the invocation and the config received by each stubbed stage
    local write_stage = { called = false, config = nil }
    local verify_stage = { called = false, config = nil }
    function vrd.write_cfg(_, config)
        write_stage.called = true
        write_stage.config = config
        return defs.RET.OK
    end
    function vrd.verify_data(_, config)
        verify_stage.called = true
        verify_stage.config = config
        return defs.RET.OK
    end

    local image = ('\x00'):rep(256)
    local status = vrd:write_file_data({
        chip = vrd.obj.RefChip,
        addr = vrd.obj.Address,
        start_addr = 0x70000000,
        file_len = 256,
        head_len = 20,
        data = image
    })
    lu.assertEquals(status, defs.RET.OK)
    -- Both the write and the verification must have been invoked
    lu.assertTrue(write_stage.called, 'write_cfg should be called')
    lu.assertTrue(verify_stage.called, 'verify_data should be called')
    -- The config must have been passed through correctly
    lu.assertEquals(write_stage.config.data, image, 'Write config should contain test data')
    lu.assertEquals(verify_stage.config.data, image, 'Verify config should contain test data')

    restore_mocks(saved)
end

-- Test: upgrade_cfg
function TestSd500x:test_upgrade_cfg()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    local main_start, backup_start = 0x70000000, 0x70004000
    local erased = false
    -- Per start address: has write_file_data been asked to write there?
    local written = { [main_start] = false, [backup_start] = false }

    function vrd.erase_cfg_blocks()
        erased = true
        return defs.RET.OK
    end
    function vrd.write_file_data(_, config)
        -- Only the two known start addresses are tracked
        if written[config.start_addr] ~= nil then
            written[config.start_addr] = true
        end
        return defs.RET.OK
    end

    local status = vrd:upgrade_cfg('/tmp/')
    lu.assertEquals(status, defs.RET.OK)
    -- Upgrade flow: erase -> write main partition -> write backup area
    lu.assertTrue(erased, 'erase_cfg_blocks should be called')
    lu.assertTrue(written[main_start], 'Main partition should be written')
    lu.assertTrue(written[backup_start], 'Backup partition should be written')

    restore_mocks(saved)
end

-- Test: switch_to_app_mode
function TestSd500x:test_switch_to_app_mode()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- With the default stub chips the switch is expected to report OK
    local status = vrd:switch_to_app_mode()
    lu.assertEquals(status, defs.RET.OK)

    restore_mocks(saved)
end

-- Test: get_app_die_id - successful read and die_id format
function TestSd500x:test_get_app_die_id()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    local want_id = ('\xAB'):rep(20)
    local reply = '\x14' .. want_id -- the first byte is the length, 20

    -- Chip answers with the reply plus a trailing checksum byte
    function vrd.obj.RefChip.Read(_, _, cmd)
        local checksum_fn = package.loaded['mc.crc8']
        local addr = vrd.obj.Address
        local header = string.char(addr, cmd, addr | 1)
        return reply .. string.char(checksum_fn(header .. reply))
    end

    local status = vrd:get_app_die_id()
    lu.assertEquals(status, defs.RET.OK)
    -- die_id format: the first byte is the length and should lie within 1-20
    local id_len = reply:byte(1)
    lu.assertTrue(id_len >= 1 and id_len <= 20, 'Die ID length should be 1-20')
    lu.assertEquals(id_len, 20, 'Die ID length should be 20')
    -- die_id content check
    local got_id = reply:sub(2, 21)
    lu.assertEquals(got_id, want_id, 'Die ID data should match expected')

    restore_mocks(saved)
end

-- Test: read_fw_version - successful read and version number content
function TestSd500x:test_read_fw_version()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    -- Version format: the first byte is the length, the next 4 bytes are the version number
    local reply = '\x04\x01\x02\x03\x04' -- version number: 1.2.3.4

    -- Chip answers with the reply plus a trailing checksum byte
    function vrd.obj.RefChip.Read(_, _, cmd)
        local checksum_fn = package.loaded['mc.crc8']
        local addr = vrd.obj.Address
        local header = string.char(addr, cmd, addr | 1)
        return reply .. string.char(checksum_fn(header .. reply))
    end

    local status = vrd:read_fw_version()
    lu.assertEquals(status, defs.RET.OK)
    -- Version content check: 1.2.3.4
    lu.assertEquals(reply:byte(2), 1, 'Version major should be 1')
    lu.assertEquals(reply:byte(3), 2, 'Version minor should be 2')
    lu.assertEquals(reply:byte(4), 3, 'Version patch should be 3')
    lu.assertEquals(reply:byte(5), 4, 'Version build should be 4')

    restore_mocks(saved)
end

-- Test: upgrade_firmware
function TestSd500x:test_upgrade_firmware()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()
    local steps = {}

    -- Builds a stub that logs `step_name` and reports OK
    local function recorder(step_name)
        return function()
            steps[#steps + 1] = step_name
            return defs.RET.OK
        end
    end

    vrd.enter_boot_rom = recorder('enter_boot_rom')
    function vrd.write_firmware_to_flash(_, data)
        steps[#steps + 1] = 'write_firmware'
        lu.assertEquals(#data, 48 * 1024, 'Firmware data should be 48KB')
        return defs.RET.OK
    end
    vrd.switch_to_app_mode = recorder('switch_to_app')
    vrd.get_app_die_id = recorder('get_die_id')
    vrd.read_fw_version = recorder('read_version')

    local status = vrd:upgrade_firmware('/tmp/')
    lu.assertEquals(status, defs.RET.OK)
    -- The upgrade steps must run in the right order
    lu.assertEquals(#steps, 5, 'Should execute 5 upgrade steps')
    lu.assertEquals(steps[1], 'enter_boot_rom', 'Should enter boot rom first')
    lu.assertEquals(steps[2], 'write_firmware', 'Should write firmware second')
    lu.assertEquals(steps[3], 'switch_to_app', 'Should switch to app third')

    restore_mocks(saved)
end

-- Test: upgrade - complete upgrade flow
function TestSd500x:test_upgrade_full()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- Flags flipped by the stubbed stages
    local done = { firmware = false, cfg = false }
    function vrd.upgrade_firmware()
        done.firmware = true
        return defs.RET.OK
    end
    function vrd.upgrade_cfg()
        done.cfg = true
        return defs.RET.OK
    end

    local status = vrd:upgrade('/tmp/')
    lu.assertEquals(status, defs.RET.OK)
    -- Both the firmware and the configuration must have been upgraded
    lu.assertTrue(done.firmware, 'Firmware should be upgraded')
    lu.assertTrue(done.cfg, 'Configuration should be upgraded')

    restore_mocks(saved)
end

-- Test: upgrade - firmware upgrade fails while the configuration upgrade succeeds
function TestSd500x:test_upgrade_firmware_fail_cfg_ok()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    function vrd.upgrade_firmware()
        return defs.RET.ERR
    end
    function vrd.upgrade_cfg()
        return defs.RET.OK
    end

    -- The overall result is expected to be OK in this combination
    local status = vrd:upgrade('/tmp/')
    lu.assertEquals(status, defs.RET.OK)

    restore_mocks(saved)
end

-- Test: valid_vrd
function TestSd500x:test_valid_vrd()
    local saved = setup_mocks()
    local vrd = create_sd500x_instance()

    -- The call must report OK and leave ValidateReg (initially 0) set to 1
    local status = vrd:valid_vrd()
    lu.assertEquals(status, defs.RET.OK)
    lu.assertEquals(vrd.obj.ValidateReg, 1)

    restore_mocks(saved)
end