-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

local ncsi_get_status = {}

-- Packet type of the Get Link Status command; its response uses the same value with bit 7 set (0x8A).
local GET_LINK_STATUS = 0x0A
local GET_LINK_STATUS_RSP = GET_LINK_STATUS | 0x80
-- Payload lengths (bytes) announced in the packet header: the request carries none,
-- the response carries 16 bytes of fields ahead of its checksum.
local GET_LINK_STATUS_REQ_LEN = 0
local GET_LINK_STATUS_RSP_LEN = 0x10
-- Size in bytes of the fixed payload area (checksum, data/padding and FCS included) used to size the frame.
local GET_LINK_STATUS_PKT_SIZE = 34
-- Command-specific reason code logged as "Link Command Failed-Hardware Access Error".
local LINK_ACCESS_ERROR = 6

-- Bitstring layout of the Get Link Status request payload:
--   check_sum : 32-bit packet checksum
--   data      : 26/string data area (left empty by fill_req_payload; presumably 26 bytes,
--               so the payload totals GET_LINK_STATUS_PKT_SIZE = 34 bytes — confirm against mc.bitstring)
--   fcs       : 32-bit frame check sequence
local get_link_status_req_bs = bs.new([[<<
    check_sum:32,
    data:26/string,
    fcs:32
>>]])

-- Bitstring layout of the Get Link Status response payload:
--   rsp_code / reason_code : 16-bit response and reason codes
--   lnk_sts_*              : 32 bits of link-status bitfields (link flag, speed/duplex,
--                            auto-negotiation state, flow control, link partner abilities, ...)
--   other_indication       : 32-bit "other indications" word
--   oem_link_status        : 32-bit OEM link status word
-- These fields total 16 bytes (GET_LINK_STATUS_RSP_LEN). They are followed by
--   check_sum              : 32-bit checksum (0 means it is not verified by the reader)
--   pad_data               : 10/string padding
--   fcs                    : 32-bit frame check sequence
-- NOTE(review): the bit order within each byte follows mc.bitstring's conventions — confirm
-- before reordering any of the single-bit fields.
local get_link_status_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    lnk_sts_reserved3:8,
    lnk_sts_tx_flow_control:1,
    lnk_sts_rx_flow_control:1,
    lnk_sts_link_partner8:2,
    lnk_sts_serdes_link:1,
    lnk_sts_oem_link_speed:1,
    lnk_sts_reserved2:2,
    lnk_sts_channel_available:1,
    lnk_sts_link_partner1:1,
    lnk_sts_link_partner2:1,
    lnk_sts_link_partner3:1,
    lnk_sts_link_partner4:1,
    lnk_sts_link_partner5:1,
    lnk_sts_link_partner6:1,
    lnk_sts_link_partner7:1,
    lnk_sts_link_flag:1,
    lnk_sts_speed_duplex:4,
    lnk_sts_negotiate_flag:1,
    lnk_sts_negotiate_complete:1,
    lnk_sts_parallel_detection:1,
    other_indication:32,
    oem_link_status:32,
    check_sum:32,
    pad_data:10/string,
    fcs:32
>>]])

-- Build the request payload: checksum, empty data field and FCS.
-- req_packet: request packet whose header is already filled in.
-- req_len:    payload length announced in the header; the checksum helper is given
--             PACKET_HEAD_LEN + req_len as its length.
-- pkt_size:   size of the payload area; the CRC helper is given PACKET_HEAD_LEN + pkt_size - 4,
--             i.e. everything up to (but excluding) the trailing 4-byte FCS field.
-- Returns the packed payload string.
local function fill_req_payload(req_packet, req_len, pkt_size)
    local head_len = ncsi_def.PACKET_HEAD_LEN
    local sum = ncsi_utils.get_checksum(req_packet, head_len + req_len)
    local crc_len = head_len + pkt_size - 4
    local fcs = ncsi_utils.get_crc32(req_packet, crc_len)
    local fields = {
        check_sum = core.htonl(sum),
        data = '',
        fcs = core.htonl(fcs),
    }
    return get_link_status_req_bs:pack(fields)
end

-- Build the Get Link Status request in req_packet and send it on interface eth_name.
-- Returns whatever ncsi_protocol_intf.send_ncsi_cmd returns.
local function write_get_link_status_req(req_packet, eth_name)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    local head = req_packet.packet_head
    -- The payload length is split over two header fields; the high part is masked to 4 bits.
    head.payload_len_hi = (GET_LINK_STATUS_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = GET_LINK_STATUS_REQ_LEN & 0xff
    req_packet.payload = fill_req_payload(req_packet, GET_LINK_STATUS_REQ_LEN, GET_LINK_STATUS_PKT_SIZE)

    local wire_data = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    -- Frame length = Ethernet header + NC-SI packet header + fixed payload area.
    local frame_len = ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN + GET_LINK_STATUS_PKT_SIZE
    return ncsi_protocol_intf.send_ncsi_cmd(wire_data, frame_len, eth_name)
end

-- Store the link status decoded from a Get Link Status response in the shared NC-SI parameters
-- and log the main fields.
-- eth_num:        first-level index into ncsi_para.link_status / ncsi_para.oem_link_status.
-- channel_id:     second-level index (the channel the response came from).
-- pst_packet_rsp: table produced by get_link_status_rsp_bs:unpack.
local function save_link_status(eth_num, channel_id, pst_packet_rsp)
    local ncsi_para = ncsi_parameter.get_instance():get_ncsi_parameter()
    -- Copy the lnk_sts_* bitfields into a link_status table, dropping the prefix.
    local link_status = {
        reserved3 = pst_packet_rsp.lnk_sts_reserved3,
        tx_flow_control = pst_packet_rsp.lnk_sts_tx_flow_control,
        rx_flow_control = pst_packet_rsp.lnk_sts_rx_flow_control,
        link_partner8 = pst_packet_rsp.lnk_sts_link_partner8,
        serdes_link = pst_packet_rsp.lnk_sts_serdes_link,
        oem_link_speed = pst_packet_rsp.lnk_sts_oem_link_speed,
        reserved2 = pst_packet_rsp.lnk_sts_reserved2,
        channel_available = pst_packet_rsp.lnk_sts_channel_available,
        link_partner1 = pst_packet_rsp.lnk_sts_link_partner1,
        link_partner2 = pst_packet_rsp.lnk_sts_link_partner2,
        link_partner3 = pst_packet_rsp.lnk_sts_link_partner3,
        link_partner4 = pst_packet_rsp.lnk_sts_link_partner4,
        link_partner5 = pst_packet_rsp.lnk_sts_link_partner5,
        link_partner6 = pst_packet_rsp.lnk_sts_link_partner6,
        link_partner7 = pst_packet_rsp.lnk_sts_link_partner7,
        link_flag = pst_packet_rsp.lnk_sts_link_flag,
        speed_duplex = pst_packet_rsp.lnk_sts_speed_duplex,
        negotiate_flag = pst_packet_rsp.lnk_sts_negotiate_flag,
        negotiate_complete = pst_packet_rsp.lnk_sts_negotiate_complete,
        parallel_detection = pst_packet_rsp.lnk_sts_parallel_detection
    }
    -- NOTE(review): assumes the per-interface tables link_status[eth_num] and
    -- oem_link_status[eth_num] already exist — confirm in ncsi_parameter.
    ncsi_para.link_status[eth_num][channel_id] = link_status
    ncsi_para.oem_link_status[eth_num][channel_id] = pst_packet_rsp.oem_link_status
    -- The format has seven conversions, so the response code goes first; every value then
    -- lines up with its label and the OEM status has an argument.
    log:info("get_link_status_rsp: ncsi cmd completion code = %d, " ..
        "link status = 0x%X, " ..
        "speed duplex = 0x%04X, " ..
        "negotiate_flag = 0x%X, " ..
        "negotiate_complete = 0x%X, " ..
        "parallel_detection = 0x%X, " ..
        "oem = %u",
        pst_packet_rsp.rsp_code,
        link_status.link_flag,
        link_status.speed_duplex,
        link_status.negotiate_flag,
        link_status.negotiate_complete,
        link_status.parallel_detection,
        pst_packet_rsp.oem_link_status)
end

-- Validate and parse a Get Link Status response, and store the link status on success.
-- rsp:    unpacked NC-SI packet (packet_head + payload).
-- eth_id: interface index under which save_link_status stores the result.
-- Returns ncsi_def.NCSI_SUCCESS when the command completed and the status was saved.
-- Returns ncsi_def.NCSI_FAIL on a wrong payload length, unpack failure, checksum mismatch
-- or a response code other than CMD_COMPLETED.
local function read_get_link_status_rsp(rsp, eth_id)
    local head = rsp.packet_head
    -- The payload length is 12 bits: 4 bits in payload_len_hi plus 8 bits in payload_len_lo.
    -- Mask the high part, exactly as the request path does, so that stray upper bits in the
    -- high byte cannot cause a false length mismatch.
    local payload_len = head.payload_len_lo | ((head.payload_len_hi & 0x0f) << 8)
    if payload_len ~= GET_LINK_STATUS_RSP_LEN then
        log:error('Invalid response length')
        return ncsi_def.NCSI_FAIL
    end

    local data = get_link_status_rsp_bs:unpack(rsp.payload, true)
    if not data then
        log:error('Failed to unpack get link status response payload')
        return ncsi_def.NCSI_FAIL
    end

    -- A checksum of 0 is not verified.
    local check_sum = data.check_sum
    if check_sum ~= 0 then
        local expected = ncsi_utils.get_checksum(rsp, ncsi_def.PACKET_HEAD_LEN + GET_LINK_STATUS_RSP_LEN)
        if check_sum ~= core.htonl(expected) then
            log:error('Get link status response checksum mismatch')
            return ncsi_def.NCSI_FAIL
        end
    end

    local rsp_code = data.rsp_code
    if rsp_code == ncsi_def.CMD_COMPLETED then
        save_link_status(eth_id, head.channel_id, data)
        return ncsi_def.NCSI_SUCCESS
    end

    ncsi_utils.common_respcode_parse(rsp_code)
    -- LINK_ACCESS_ERROR is specific to this command; other reason codes use the common parser.
    if data.reason_code == LINK_ACCESS_ERROR then
        log:error('Link Command Failed-Hardware Access Error')
    else
        ncsi_utils.common_reasoncode_parse(data.reason_code)
    end
    return ncsi_def.NCSI_FAIL
end

-- Command handler table keyed by NC-SI packet type:
-- the command type maps to the request writer, the response type to the response reader.
local get_status_table = {
    [GET_LINK_STATUS] = write_get_link_status_req,      -- build and send the request
    [GET_LINK_STATUS_RSP] = read_get_link_status_rsp,   -- validate and parse the response
}

-- Send a Get Link Status command to (package_id, channel_id) over interface eth_name.
-- On success the parsed link status is stored under eth_id by the response handler.
-- Returns the result of ncsi_utils.ncsi_cmd_ctrl; anything other than ncsi_def.NCSI_SUCCESS is logged.
function ncsi_get_status.ncsi_get_link_status(package_id, channel_id, eth_name, eth_id)
    local request = ncsi_packet.create_request_packet(package_id, channel_id, GET_LINK_STATUS)

    -- Per-call copy of the handler table in which the response reader is bound to eth_id,
    -- so the shared get_status_table is never modified.
    local handlers = {}
    for pkt_type, handler in pairs(get_status_table) do
        handlers[pkt_type] = handler
    end
    handlers[GET_LINK_STATUS_RSP] = function(rsp)
        return read_get_link_status_rsp(rsp, eth_id)
    end

    local result = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, request, eth_name, handlers)
    if result == ncsi_def.NCSI_SUCCESS then
        return result
    end
    log:error('ncsi cmd ctrl get link status failed, package_id = %s, channel_id = %s, eth_name = %s',
        package_id, channel_id, eth_name)
    return result
end

-- Module table; exposes ncsi_get_status.ncsi_get_link_status.
return ncsi_get_status