-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local lu = require 'luaunit'
local c_object_manage = require 'mc.orm.object_manage'
local npu_imu_cmd = require 'npu.hdk_cmd'
local fructl = require 'infrastructure.fructl'
local ipmi = require 'ipmi'
skynet = require 'skynet'
local comp_code = ipmi.types.Cc
local bs = require 'mc.bitstring'
local c_object_manage = require 'mc.orm.object_manage'

-- luaunit test suite for npu.hdk_cmd.
-- NOTE(review): declared without `local`, presumably so the luaunit runner can
-- discover the suite by its global name — keep it global.
TEST_hdk_cmd = {}

--- Stub for fructl.get_power_status: ignores its arguments and always reports 'ON'.
local function mock_get_power_status(...)
    local power_state = 'ON'
    return power_state
end

--- Stub for c_object_manage.get_instance: returns a new fake object whose only
-- field is `bus = 'test'`.
local function mock_get_instance(...)
    local fake_instance = { bus = 'test' }
    return fake_instance
end

--- Stub for ipmi.request: answers every request with completion code 0x01 and
-- an empty payload, whatever the arguments.
local function mock_ipmi_request(...)
    local completion_code, payload = 0x01, ''
    return completion_code, payload
end

--- Trim a mock response so it starts at the requested offset.
-- Long commands (e.g. log collection) are answered in chunks, so the mock must
-- reply according to the offset carried in each request.
-- @tparam table mock_resp fixture fields later packed by the response format
-- @tparam[opt] integer offset requested offset; nil or 0 means "from the start"
-- @treturn table `mock_resp` itself when offset is nil or 0, otherwise a trimmed
--   shallow copy (the argument is never modified)
local function update_mock_resp(mock_resp, offset)
    if offset == nil or offset == 0 then
        return mock_resp
    end

    -- Work on a copy: fixtures are shared between requests, and trimming them
    -- in place would apply every later offset on top of the earlier ones.
    local trimmed = {}
    for key, value in pairs(mock_resp) do
        trimmed[key] = value
    end

    -- Only the variable-length trailing field is trimmed. Other response
    -- layouts may need another branch here.
    if trimmed.tail then
        trimmed.tail = string.sub(trimmed.tail, 1 + offset)
    elseif trimmed.data then
        trimmed.data = string.sub(trimmed.data, 1 + offset)
    end

    return trimmed
end

--- Replace ipmi.request with a stub that answers every request from a fixture.
-- For each call the request payload is decoded with `mock_data.req_format`, and
-- the fixture is trimmed to the requested offset and packed with
-- `mock_data.resp_format`. The fixture table itself is never modified, so the
-- stub answers consistently no matter how many requests the code under test sends.
-- @tparam table mock_data fixture with fields comp_code, mock_resp, req_format, resp_format
local function mock_ipmi(mock_data)
    local mock_comp_code = mock_data.comp_code
    local mock_resp = mock_data.mock_resp
    local req_format, resp_format = mock_data.req_format, mock_data.resp_format

    ipmi.request = function(...)
        -- Only the third argument (the request object) is inspected.
        local _, _, req = ...
        local payload = req_format:unpack(req.Payload)

        -- Build the answer in per-call locals. Reassigning the captured
        -- `mock_resp` upvalue would replace the fixture with its packed string
        -- after the first request and break every later one.
        local resp = {}
        for key, value in pairs(mock_resp) do
            resp[key] = value
        end
        resp = update_mock_resp(resp, payload.offset)

        return mock_comp_code, resp_format:pack(resp)
    end
end

-- Real implementations, captured once when this file loads so that tearDown can
-- put them back after each test's stubs.
local get_instance = c_object_manage.get_instance
local ipmi_request = ipmi.request
local get_power_status = fructl.get_power_status

--- Install the default stubs before every test:
-- * ipmi.request answers completion code 0x01 with an empty payload,
-- * c_object_manage.get_instance returns a fake object on bus 'test',
-- * fructl.get_power_status reports 'ON'.
function TEST_hdk_cmd:setUp()
    ipmi.request = mock_ipmi_request
    c_object_manage.get_instance = mock_get_instance
    fructl.get_power_status = mock_get_power_status
end


--- Restore the implementations captured at load time. This discards both the
-- setUp stubs and any stub a test installed itself.
function TEST_hdk_cmd:tearDown()
    ipmi.request = ipmi_request
    c_object_manage.get_instance = get_instance
    fructl.get_power_status = get_power_status
end

--- With the default setUp stub (completion code 0x01, empty payload),
-- get_info_from_imu is expected to report the IP address as 'N/A'.
function TEST_hdk_cmd:test_get_info_from_imu()
    local expected_ip = 'N/A'
    local info = npu_imu_cmd.get_info_from_imu(1, 1)
    lu.assertEquals(info.ip_addr, expected_ip)
end

-- Fixture for test_get_op_runtime_info: stubbed IMU answer for the optical
-- module runtime-info query (consumed by mock_ipmi).
-- req_format decodes the request sent by the code under test; mock_ipmi only
-- uses its `offset` field.
-- resp_format lays out the answer as follows:
-- * a fixed header,
-- * a 7-byte "supported" bitmap (9 flag bits + 47 reserved bits), in which each
--   bit says whether the correspondingly ordered field after it is supported,
-- * those data fields, in the same order.
-- mock_resp sets every supported flag to 1.
local test_get_op_runtime_info<const> = {
    comp_code = comp_code.Success,
    req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        request_parameter:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8,
        request_type:1/unit:8
        >>]]),
    resp_format = bs.new([[<<
        error_code:1/unit:8,
        opcode:2/little-unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        cmd_version:1/unit:8,

        # support_info 有7个字节,每个bit代表接下来结构体中对应顺序的属性是否支持
        runtime_supported:1,
        poweron_time_supported:1,
        poweron_count_supported:1,
        power_status_supported:1,
        odsp_temp_supported:1,
        odsp_high_heat_time_supported:1,
        laser_run_time_supported:1,
        laser_temp_supported:1,
        laser_core_temp_supported:1,
        reserved:47,

        runtime:4/little-unit:8,
        poweron_time:4/little-unit:8,
        poweron_count:2/little-unit:8,
        power_status:2/little-unit:8,
        odsp_temp:2/little-unit:8,
        odsp_high_heat_time:4/little-unit:8,
        laser_run_time:4/little-unit:8,
        laser_temp:2/little-unit:8,
        laser_core_temp:2/little-unit:8,
        tail/string
        >>]]),
    mock_resp = {
        error_code = 0,
        opcode = 0,
        total_length = 0,
        length = 0,
        cmd_version = 0,
        runtime_supported = 1,
        poweron_time_supported = 1,
        poweron_count_supported = 1,
        power_status_supported = 1,
        odsp_temp_supported = 1,
        odsp_high_heat_time_supported = 1,
        laser_run_time_supported = 1,
        laser_temp_supported = 1,
        laser_core_temp_supported = 1,
        reserved = 1,
        runtime = 4,
        poweron_time = 4,
        poweron_count = 2,
        power_status = 2,
        odsp_temp = 2,
        odsp_high_heat_time = 4,
        laser_run_time = 4,
        laser_temp = 2,
        laser_core_temp = 0,
        tail =""
    }
}

--- get_op_runtime_info must complete without raising when the IMU answers
-- with the test_get_op_runtime_info fixture.
function TEST_hdk_cmd:test_get_op_runtime_info()
    mock_ipmi(test_get_op_runtime_info)
    local ok, err = pcall(npu_imu_cmd.get_op_runtime_info, 1)
    -- Report the error through the assertion message instead of printing it
    -- unconditionally. On success `err` is the function's first return value,
    -- not an error.
    lu.assertEquals(ok, true, tostring(err))
end

-- Fixture for test_get_op_base_info: optical module base-info answer whose
-- layout carries optical_type and channel_num after transceiver_type.
-- total_length 131 = 4 * 32 string bytes + 3 one-byte fields.
-- test_get_op_base_info expects transceiver_type 16 to be reported as
-- "400G BASE-SR8" and a channel count of 4.
local test_get_op_base_info_new<const> = {
    comp_code = comp_code.Success,
    req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        para:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8
    >>]]),
    resp_format = bs.new([[<<
        head:3/unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        vendor_name:32/string,
        serial_number:32/string,
        part_number:32/string,
        manufacture_date:32/string,
        transceiver_type:1/little-unit:8,
        optical_type:1/little-unit:8,
        channel_num:1/little-unit:8,
        tail/binary
    >>]]),
    mock_resp = {
        head = 0,
        total_length = 131,
        length = 1,
        vendor_name = '11111111111111111111111111111111',
        serial_number = '11111111111111111111111111111111',
        part_number = '11111111111111111111111111111111',
        manufacture_date = '11111111111111111111111111111111',
        transceiver_type = 16,
        optical_type = 2,
        channel_num = 4,
        tail = ''
    }
}

-- Fixture for test_get_op_base_info: base-info answer with total_length 130.
-- Its resp_format stops at transceiver_type, and two extra bytes (2, 1) are
-- appended through `tail`. test_get_op_base_info expects a channel count of 8
-- here, i.e. the appended bytes must not be taken as the channel count.
-- NOTE(review): presumably the parser treats total_length 130 as the older
-- layout without channel_num and falls back to 8 lanes — confirm in npu.hdk_cmd.
local test_get_op_base_info_old<const> = {
    comp_code = comp_code.Success,
    req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        para:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8
    >>]]),
    resp_format = bs.new([[<<
        head:3/unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        vendor_name:32/string,
        serial_number:32/string,
        part_number:32/string,
        manufacture_date:32/string,
        transceiver_type:1/little-unit:8,
        tail/binary
    >>]]),
    mock_resp = {
        head = 0,
        total_length = 130,
        length = 1,
        vendor_name = '11111111111111111111111111111111',
        serial_number = '11111111111111111111111111111111',
        part_number = '11111111111111111111111111111111',
        manufacture_date = '11111111111111111111111111111111',
        transceiver_type = 16,
        tail = '\02\01'
    }
}

-- Fixture for test_get_op_base_info: total_length announces 131 bytes, but the
-- packed answer ends right after transceiver_type (empty tail), so no
-- channel_num byte is present. test_get_op_base_info expects a channel count
-- of 8 in this case.
-- ("channer" in the name is a typo for "channel"; it is kept because
-- test_get_op_base_info refers to this name.)
local test_get_op_base_info_channer_num_err<const> = {
    comp_code = comp_code.Success,
    req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        para:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8
    >>]]),
    resp_format = bs.new([[<<
        head:3/unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        vendor_name:32/string,
        serial_number:32/string,
        part_number:32/string,
        manufacture_date:32/string,
        transceiver_type:1/little-unit:8,
        tail/binary
    >>]]),
    mock_resp = {
        head = 0,
        total_length = 131,
        length = 1,
        vendor_name = '11111111111111111111111111111111',
        serial_number = '11111111111111111111111111111111',
        part_number = '11111111111111111111111111111111',
        manufacture_date = '11111111111111111111111111111111',
        transceiver_type = 0,
        tail = ''
    }
}

-- Fixture for test_get_op_base_info: base-info answer with total_length 135 and
-- eight extra bytes appended through `tail` after transceiver_type.
-- test_get_op_base_info expects a channel count of 4 and transceiver type
-- "100G AOC" for this answer.
-- NOTE(review): the tail bytes look like optical_type (3), channel_num (4), one
-- more byte (3), a little-endian 32-bit 0x000186A0 (100000; presumably a speed
-- value) and a trailing 0 — confirm the layout against npu.hdk_cmd.
local test_get_op_speed_info<const> = {
    comp_code = comp_code.Success,
    req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        para:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8
    >>]]),
    resp_format = bs.new([[<<
        head:3/unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        vendor_name:32/string,
        serial_number:32/string,
        part_number:32/string,
        manufacture_date:32/string,
        transceiver_type:1/little-unit:8,
        tail/binary
    >>]]),
    mock_resp = {
        head = 0,
        total_length = 135,
        length = 135,
        vendor_name = '11111111111111111111111111111111',
        serial_number = '11111111111111111111111111111111',
        part_number = '11111111111111111111111111111111',
        manufacture_date = '11111111111111111111111111111111',
        transceiver_type = 16,
        tail = '\03\04\03\xa0\x86\x01\00\x00'
    }
}

--- Query mock_get_op_base_info against four stubbed IMU answers and check the
-- completion code, the reported transceiver type and the channel count.
function TEST_hdk_cmd:test_get_op_base_info()
    local bus = c_object_manage.get_instance().bus
    -- Remember the currently installed ipmi.request and reinstall it at the end.
    local saved_request = ipmi.request

    -- Install `fixture` as the IPMI answer and query the base info once.
    local function query_base_info(fixture)
        mock_ipmi(fixture)
        return npu_imu_cmd.mock_get_op_base_info(bus, 1)
    end

    do -- layout with explicit optical_type/channel_num fields
        local cc, vendor, serial, trans_type, mfg_date, part_no, channels =
            query_base_info(test_get_op_base_info_new)
        lu.assertEquals(cc, 0)
        lu.assertEquals(trans_type, "400G BASE-SR8")
        print(vendor, serial, trans_type, mfg_date, part_no)
        lu.assertEquals(channels, 4)
    end

    do -- total_length 130, no channel_num field in the layout
        local cc, vendor, serial, trans_type, mfg_date, part_no, channels =
            query_base_info(test_get_op_base_info_old)
        lu.assertEquals(cc, 0)
        print(vendor, serial, trans_type, mfg_date, part_no)
        lu.assertEquals(channels, 8)
        lu.assertEquals(trans_type, "400G BASE-SR8")
    end

    do -- total_length 131 but the answer ends after transceiver_type
        local cc, vendor, serial, trans_type, mfg_date, part_no, channels =
            query_base_info(test_get_op_base_info_channer_num_err)
        lu.assertEquals(cc, 0)
        lu.assertEquals(channels, 8)
        print(vendor, serial, trans_type, mfg_date, part_no)
    end

    do -- total_length 135 with extra trailing bytes
        local cc, vendor, serial, trans_type, mfg_date, part_no, channels =
            query_base_info(test_get_op_speed_info)
        lu.assertEquals(cc, 0)
        lu.assertEquals(channels, 4)
        lu.assertEquals(trans_type, "100G AOC")
        print(vendor, serial, trans_type, mfg_date, part_no)
    end

    ipmi.request = saved_request
end

-- Fixture for test_get_op_status_info: stubbed IMU answer for the optical module
-- status query (consumed by mock_ipmi). resp_format has eight slots for every
-- per-lane reading. The mock answer zeroes every reading except the SNR values:
-- lane i reports host SNR i and media SNR 11 * i.
local test_get_op_status_info<const> = (function()
    local answer = {
        head = 0,
        total_length = 1,
        length = 1,
        voltage = 0,
        tx_los = 0,
        rx_los = 0,
        tx_lol = 0,
        rx_lol = 0,
        temperature = 0,
        tx_fault = 0,
        access_failed = 0,
        tail = "",
    }
    -- Per-lane fields tx_power1..8, rx_power1..8, tx_bias1..8, host_snr1..8, media_snr1..8.
    for lane = 1, 8 do
        answer['tx_power' .. lane] = 0
        answer['rx_power' .. lane] = 0
        answer['tx_bias' .. lane] = 0
        answer['host_snr' .. lane] = lane
        answer['media_snr' .. lane] = 11 * lane
    end

    return {
        comp_code = comp_code.Success,
        req_format = bs.new([[<<
        0xDB0700:3/unit:8,
        cmd:1/unit:8,
        lun:1/unit:8,
        para:1/unit:8,
        op_cmd:1/unit:8,
        op_fun:1/unit:8,
        offset:4/unit:8,
        data_length:4/unit:8
        >>]]),
        resp_format = bs.new([[<<
        head:3/unit:8,
        total_length:4/little-unit:8,
        length:4/little-unit:8,
        voltage:4/little-unit:8,
        tx_power1:2/big-unit:8,
        tx_power2:2/big-unit:8,
        tx_power3:2/big-unit:8,
        tx_power4:2/big-unit:8,
        tx_power5:2/big-unit:8,
        tx_power6:2/big-unit:8,
        tx_power7:2/big-unit:8,
        tx_power8:2/big-unit:8,
        rx_power1:2/big-unit:8,
        rx_power2:2/big-unit:8,
        rx_power3:2/big-unit:8,
        rx_power4:2/big-unit:8,
        rx_power5:2/big-unit:8,
        rx_power6:2/big-unit:8,
        rx_power7:2/big-unit:8,
        rx_power8:2/big-unit:8,
        tx_bias1:2/big-unit:8,
        tx_bias2:2/big-unit:8,
        tx_bias3:2/big-unit:8,
        tx_bias4:2/big-unit:8,
        tx_bias5:2/big-unit:8,
        tx_bias6:2/big-unit:8,
        tx_bias7:2/big-unit:8,
        tx_bias8:2/big-unit:8,
        tx_los:4/little-unit:8,
        rx_los:4/little-unit:8,
        tx_lol:4/little-unit:8,
        rx_lol:4/little-unit:8,
        temperature:4/little-unit:8,
        tx_fault:4/little-unit:8,
        host_snr1:32/little-float,
        host_snr2:32/little-float,
        host_snr3:32/little-float,
        host_snr4:32/little-float,
        host_snr5:32/little-float,
        host_snr6:32/little-float,
        host_snr7:32/little-float,
        host_snr8:32/little-float,
        media_snr1:32/little-float,
        media_snr2:32/little-float,
        media_snr3:32/little-float,
        media_snr4:32/little-float,
        media_snr5:32/little-float,
        media_snr6:32/little-float,
        media_snr7:32/little-float,
        media_snr8:32/little-float,
        access_failed:1/little-unit:8,
        tail/string
        >>]]),
        mock_resp = answer,
    }
end)()

--- Decode the status fixture for a 4-lane and then an 8-lane module, and check
-- that every per-lane list has exactly one entry per requested lane.
function TEST_hdk_cmd:test_get_op_status_info()
    local bus = c_object_manage.get_instance().bus
    -- Remember the currently installed ipmi.request and reinstall it at the end.
    local saved_request = ipmi.request

    for _, lane_count in ipairs({ 4, 8 }) do
        -- A fresh stub for every query.
        mock_ipmi(test_get_op_status_info)
        local cc, voltage, tx_power, rx_power, tx_bias, tx_los, rx_los, tx_lol, rx_lol,
            temperature, tx_fault, host_snr, media_snr, access_failed =
            npu_imu_cmd.mock_get_op_status_info(bus, 1, lane_count)
        print(voltage, tx_los, rx_los, tx_lol, rx_lol, temperature, tx_fault, access_failed)
        lu.assertEquals(cc, 0)
        lu.assertEquals(#tx_power, lane_count)
        lu.assertEquals(#rx_power, lane_count)
        lu.assertEquals(#tx_bias, lane_count)
        lu.assertEquals(#host_snr, lane_count)
        lu.assertEquals(#media_snr, lane_count)
    end

    ipmi.request = saved_request
end

--- get_npu_cdr_temp_from_imu must not raise when the IMU answers with a success
-- completion code (0) and the 4-byte payload "test".
-- The previous version built an `unpack` mock that nothing used and then
-- asserted on its own local value, so that assertion could never fail.
function TEST_hdk_cmd:test_get_npu_cdr_temp_from_imu()
    ipmi.request = function()
        return 0, "test"
    end
    local ok, err = pcall(npu_imu_cmd.get_npu_cdr_temp_from_imu, 1, true)
    lu.assertEquals(ok, true, tostring(err))
end

--- Join a sequence of raw byte strings into one payload string.
-- @tparam table obj sequence (no nil holes) of strings, in wire order
-- @treturn string the pieces concatenated in order; '' for an empty sequence
local function pack(obj)
    -- table.concat builds the result in a single pass instead of the quadratic
    -- repeated `..` concatenation.
    return table.concat(obj)
end

--- get_info_from_imu should report the six trailing bytes of the IMU answer as
-- an upper-case, colon-separated MAC address.
function TEST_hdk_cmd:test_get_mac_from_imu()
    -- NOTE(review): the leading 3 + 4 + 4 zero bytes presumably mirror the
    -- head/total_length/length header used by the other fixtures; the last six
    -- bytes are the MAC.
    local response_parts = {
        '\x00\x00\x00',
        '\x00\x00\x00\x00',
        '\x00\x00\x00\x00',
        '\x3a', '\x3b', '\x3c', '\x3d', '\x3e', '\x3f',
    }
    local mock_response = pack(response_parts)
    ipmi.request = function(...)
        return 0, mock_response
    end

    local info = npu_imu_cmd.get_info_from_imu(1, 1)
    lu.assertEquals(info.mac_addr, '3A:3B:3C:3D:3E:3F')
end
