-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

local ncsi_capabilities = {}

-- NC-SI packet types handled by this module: the Get Capabilities command and
-- its response type (0x96 == 0x16 | 0x80).
local GET_CAPABILITIES = 0x16
local GET_CAPABILITIES_RSP = 0x96

-- Size in bytes of the request tail packed by capabilities_req_bs
-- (4-byte checksum + 26-byte data field + 4-byte fcs).
local CAPABILITIES_PKT_SIZE = 34
-- Payload length written into the request header: the command carries no payload.
local CAPABILITIES_REQ_LEN = 0
-- Payload length a valid response must report: the fields of capabilities_rsp_bs
-- from rsp_code through channel_cnt (checksum and fcs excluded).
local CAPABILITIES_RSP_LEN = 32
-- Used as exclusive upper bounds when indexing ncsi_para.channel_cap
-- (package ids 0..7 and channel ids 0..3 are accepted).
local NCSI_PACKAGE_MAX_ID = 8
local NCSI_CHANNEL_MAX_ID = 4

-- Request tail appended after the NC-SI header: the header checksum, a 26-byte
-- data field (packed from an empty string; presumably zero padding up to the
-- minimum Ethernet frame size -- confirm against mc.bitstring semantics) and the fcs.
local capabilities_req_bs = bs.new([[<<
    check_sum:32,
    data:26/string,
    fcs:32
>>]])

-- Get Capabilities response layout as consumed by read_capabilities_rsp: the
-- 32-byte response payload followed by the checksum and fcs words.
-- NOTE(review): the field name 'capality_flag' is a typo of 'capability_flag' but is
-- referenced by that name when the response is parsed.
local capabilities_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    capality_flag:32,
    broadcast_filter_cap:32,
    multicast_filter_cap:32,
    buffer_cap:32,
    aen_support:32,
    vlan_filter_cnt:8,
    mix_filter_cnt:8,
    mul_filter_cnt:8,
    unicast_filter_cnt:8,
    reserved:2/string,
    vlan_mode:8,
    channel_cnt:8,
    check_sum:32,
    fcs:32
>>]])

--- Build a Get Capabilities request and send it on the given interface.
-- The header announces an empty payload; the tail packed after the header holds
-- the header checksum, a 26-byte data field built from an empty string and the fcs.
-- @param req_packet request packet table (assumed to come from ncsi_packet.create_request_packet)
-- @param eth_name name of the interface the frame is sent on
-- @return whatever ncsi_protocol_intf.send_ncsi_cmd returns
local function write_capabilities_req(req_packet, eth_name)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    local head = req_packet.packet_head
    head.payload_len_hi = (CAPABILITIES_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = CAPABILITIES_REQ_LEN & 0xff

    -- Both sums are taken before the payload field is filled in.
    local sum_len = ncsi_def.PACKET_HEAD_LEN + CAPABILITIES_REQ_LEN
    local crc_len = ncsi_def.PACKET_HEAD_LEN + CAPABILITIES_PKT_SIZE - 4
    local header_sum = ncsi_utils.get_checksum(req_packet, sum_len)
    local frame_crc = ncsi_utils.get_crc32(req_packet, crc_len)

    req_packet.payload = capabilities_req_bs:pack({
        check_sum = core.htonl(header_sum),
        data = '',
        fcs = core.htonl(frame_crc),
    })

    local frame = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    local frame_len = CAPABILITIES_PKT_SIZE + ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN
    return ncsi_protocol_intf.send_ncsi_cmd(frame, frame_len, eth_name)
end

--- Parse a Get Capabilities response and store the reported capabilities.
-- On success, the entry ncsi_para.channel_cap[package_id][channel_id] and the
-- top-level ncsi_para.multicast_filter_cap / ncsi_para.channel_cnt fields are
-- updated from the response.
-- @param rsp response packet table with `packet_head` and `payload` fields
-- @return ncsi_def.NCSI_SUCCESS on success; ncsi_def.NCSI_FAIL when the payload
--         length, payload layout, checksum, response code or package/channel id
--         is invalid
local function read_capabilities_rsp(rsp)
    local head = rsp.packet_head

    -- The payload length is a 12-bit field. Mask the high part with 0x0f, as is
    -- done when the request length is written, so that set reserved upper bits of
    -- the header word cannot reject an otherwise valid response.
    local payload_len = ((head.payload_len_hi & 0x0f) << 8) | head.payload_len_lo
    if payload_len ~= CAPABILITIES_RSP_LEN then
        log:error('Invalid capabilities response payload length %s, expected %s',
            payload_len, CAPABILITIES_RSP_LEN)
        return ncsi_def.NCSI_FAIL
    end

    local data = capabilities_rsp_bs:unpack(rsp.payload, true)
    if not data then
        log:error('Failed to unpack capabilities response payload')
        return ncsi_def.NCSI_FAIL
    end

    -- A zero checksum field skips verification. Otherwise it must match the
    -- checksum computed over the header plus the response payload.
    local check_sum = data.check_sum
    if check_sum ~= 0 then
        local calc_check_sum = ncsi_utils.get_checksum(rsp, ncsi_def.PACKET_HEAD_LEN + CAPABILITIES_RSP_LEN)
        if check_sum ~= core.htonl(calc_check_sum) then
            log:error('Capabilities response checksum mismatch')
            return ncsi_def.NCSI_FAIL
        end
    end

    local rsp_code = data.rsp_code
    if rsp_code ~= ncsi_def.NCSI_SUCCESS then
        ncsi_utils.common_respcode_parse(rsp_code)
        ncsi_utils.common_reasoncode_parse(data.reason_code)
        return ncsi_def.NCSI_FAIL
    end

    local row = head.package_id
    local col = head.channel_id
    if row >= NCSI_PACKAGE_MAX_ID or col >= NCSI_CHANNEL_MAX_ID then
        log:error('package id %s or channel id %s is out of range', row, col)
        return ncsi_def.NCSI_FAIL
    end

    local ncsi_para = ncsi_parameter.get_instance():get_ncsi_parameter()
    -- Look the per-channel entry up once instead of re-indexing it for every field.
    local cap = ncsi_para.channel_cap[row][col]
    cap.capbility_flag = data.capality_flag
    cap.broadcast_filter_capality = data.broadcast_filter_cap
    cap.multicast_filter_capality = data.multicast_filter_cap
    cap.buffer_capality = data.buffer_cap
    cap.aen_support = data.aen_support
    cap.vlan_filter_cnt = data.vlan_filter_cnt
    cap.mix_filter_cnt = data.mix_filter_cnt
    cap.mul_filter_cnt = data.mul_filter_cnt
    cap.unicast_filter_cnt = data.unicast_filter_cnt
    cap.reserved = data.reserved
    cap.vlan_mode = data.vlan_mode
    cap.channel_cnt = data.channel_cnt

    ncsi_para.multicast_filter_cap = data.multicast_filter_cap
    ncsi_para.channel_cnt = data.channel_cnt
    return ncsi_def.NCSI_SUCCESS
end

-- Handlers keyed by NC-SI packet type and passed to ncsi_utils.ncsi_cmd_ctrl:
-- the response type maps to the response parser, and the command type maps to
-- the request builder.
local ncsi_capabilities_table = {
    [GET_CAPABILITIES_RSP] = read_capabilities_rsp,
    [GET_CAPABILITIES] = write_capabilities_req,
}

--- Run the Get Capabilities command against one NC-SI channel.
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id within the package
-- @param eth_name interface used to exchange the NC-SI packets
-- @return the status returned by ncsi_utils.ncsi_cmd_ctrl; a failure is also logged
function ncsi_capabilities.ncsi_get_capabilities(package_id, channel_id, eth_name)
    local packet = ncsi_packet.create_request_packet(package_id, channel_id, GET_CAPABILITIES)
    local status = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, packet, eth_name, ncsi_capabilities_table)
    if status == ncsi_def.NCSI_SUCCESS then
        return status
    end

    log:error('ncsi cmd ctrl get capabilities failed, package_id = %s, channel_id = %s, eth_name = %s',
        package_id, channel_id, eth_name)
    return status
end

return ncsi_capabilities