-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local core = require 'network.core'
local ncsi_protocol_intf = require 'ncsi_protocol_intf'
local ncsi_utils = require 'ncsi.ncsi_protocol.ncsi_utils'
local ncsi_def = require 'ncsi.ncsi_protocol.ncsi_def'
local ncsi_packet = require 'ncsi.ncsi_protocol.ncsi_packet'
local ncsi_parameter = require 'ncsi.ncsi_protocol.ncsi_parameter'

-- Module table; ncsi_get_version_id is defined on it below and it is what this
-- file returns.
local ncsi_version = {}

-- NC-SI packet type of the Get Version ID command and of its response
-- (0x95 is the command value with bit 7 set: 0x15 | 0x80).
local GET_VERSION_ID = 0x15
local GET_VERSION_ID_RSP = 0x95

-- Size of the request body packed by version_req_bs:
-- check_sum (4) + data (26) + fcs (4) = 34 bytes.
local VERSION_PKT_SIZE = 34
-- Payload length written into the request's NC-SI header (the request
-- declares an empty payload).
local VERSION_REQ_LEN = 0
-- Payload length the response header must report: the version_rsp_bs fields
-- from rsp_code through manufacturer_id add up to
-- 2+2+1+1+1+1+3+1+12+4+2+2+2+2+4 = 40 bytes; the trailing check_sum and fcs
-- are not counted.
local VERSION_RSP_LEN = 40

-- Request body that follows the NC-SI header. The request payload itself is
-- empty (VERSION_REQ_LEN = 0); data is packed from an empty string.
-- NOTE(review): data presumably acts as zero padding up to a minimum frame
-- size -- confirm against mc.bitstring's pack behavior for short strings.
local version_req_bs = bs.new([[<<
    check_sum:32,
    data:26/string,
    fcs:32
>>]])

-- Get Version ID response payload. major_ver/minor_ver/update_ver are decoded
-- as BCD by parse_get_version_id_rsp; name_string (12 bytes) and firmware_ver
-- (4 bytes) are raw fixed-width byte strings.
-- NOTE(review): the PCI ids are byte-swapped after unpacking and check_sum is
-- compared against core.htonl(...), which suggests integer fields are
-- unpacked in host (little-endian) order rather than wire order -- confirm.
local version_rsp_bs = bs.new([[<<
    rsp_code:16,
    reason_code:16,
    major_ver:8,
    minor_ver:8,
    update_ver:8,
    alpha1:8,
    reserved:3/string,
    alpha2:8,
    name_string:12/string,
    firmware_ver:4/string,
    pci_did:16,
    pci_vid:16,
    pci_ssid:16,
    pci_svid:16,
    manufacturer_id:32,
    check_sum:32,
    fcs:32
>>]])

--- Fill in and transmit a Get Version ID request frame.
-- The NC-SI header gets the common command configuration and a zero payload
-- length; the body carries the header checksum and an FCS, both converted to
-- network byte order with core.htonl.
-- @param req_packet request packet created by ncsi_packet.create_request_packet
-- @param eth_name   interface the frame is sent on
-- @return whatever ncsi_protocol_intf.send_ncsi_cmd returns
local function write_version_req(req_packet, eth_name)
    ncsi_utils.ncsi_cmd_common_config(req_packet)

    -- Only the low 4 bits of the high byte are written into the length field.
    local head = req_packet.packet_head
    head.payload_len_hi = (VERSION_REQ_LEN >> 8) & 0x0f
    head.payload_len_lo = VERSION_REQ_LEN & 0xff

    -- Checksum spans the header plus the (empty) request payload; the FCS spans
    -- the header plus every body byte that precedes the fcs field itself.
    local checksum_span = ncsi_def.PACKET_HEAD_LEN + VERSION_REQ_LEN
    local fcs_span = ncsi_def.PACKET_HEAD_LEN + VERSION_PKT_SIZE - 4
    local host_checksum = ncsi_utils.get_checksum(req_packet, checksum_span)
    -- NOTE(review): get_crc32 runs before req_packet.payload is replaced below;
    -- confirm it does not depend on the payload contents.
    local host_fcs = ncsi_utils.get_crc32(req_packet, fcs_span)

    req_packet.payload = version_req_bs:pack({
        check_sum = core.htonl(host_checksum),
        data = '',
        fcs = core.htonl(host_fcs)
    })

    local frame_len = VERSION_PKT_SIZE + ncsi_def.ETHERNET_HEAD_LEN + ncsi_def.PACKET_HEAD_LEN
    local frame = ncsi_utils.ncsi_packet_bs:pack(req_packet)
    return ncsi_protocol_intf.send_ncsi_cmd(frame, frame_len, eth_name)
end

--- Decode one packed-BCD byte into its decimal value (0-99).
-- The high nibble is the tens digit and the low nibble the ones digit; a
-- nibble above 9 is not a valid BCD digit and counts as 0.
-- @tparam integer data byte to decode (only bits 0-7 are examined)
-- @treturn integer
local function bcd_code_convert(data)
    local function digit(nibble)
        if nibble > 9 then
            return 0
        end
        return nibble
    end

    local tens = digit((data >> 4) & 0x0F)
    local ones = digit(data & 0x0F)
    return tens * 10 + ones
end

--- Swap the two bytes of a 16-bit value (0xAABB -> 0xBBAA).
-- Always swaps regardless of host byte order; only the low 16 bits of the
-- result are kept.
-- @tparam integer num
-- @treturn integer
local function short_by_big_endian(num)
    local moved_down = (num >> 8) & 0xffff
    local moved_up = (num << 8) & 0xffff
    return moved_down | moved_up
end

--- Build the NC-SI version string from the BCD version fields of a
-- Get Version ID response, e.g. "1.1", "1.2.3" or "1.0a".
-- The update component is omitted when update_ver is 0xff; alpha1 and alpha2
-- are appended as single characters only when non-zero.
-- @param rsp unpacked response payload (fields of version_rsp_bs)
-- @treturn string
local function format_ncsi_version(rsp)
    local version = string.format('%u.%u',
        bcd_code_convert(rsp.major_ver),
        bcd_code_convert(rsp.minor_ver))

    if rsp.update_ver ~= 0xff then
        version = version .. string.format(".%u", bcd_code_convert(rsp.update_ver))
    end
    if rsp.alpha1 ~= 0 then
        version = version .. string.format("%c", rsp.alpha1)
    end
    if rsp.alpha2 ~= 0 then
        version = version .. string.format("%c", rsp.alpha2)
    end
    return version
end

--- Cut a fixed-width string field at its first NUL byte.
-- name_string is unpacked as exactly 12 bytes, so a shorter name arrives with
-- trailing 0x00 padding; keeping those bytes would leave embedded NULs in the
-- stored Lua string (which e.g. D-Bus string properties cannot carry).
-- @tparam string raw
-- @treturn string raw up to (not including) the first NUL, or raw unchanged
local function strip_nul_padding(raw)
    local nul_pos = string.find(raw, '\0', 1, true)
    if nul_pos then
        return string.sub(raw, 1, nul_pos - 1)
    end
    return raw
end

--- Store the fields of a successful Get Version ID response in the shared
-- NC-SI parameter table (ncsi_parameter.get_instance():get_ncsi_parameter()).
-- Writes ncsi_ver, firmware_name, firmware_ver, manufacture_id and the four
-- pcie_device_ids entries.
-- @param pst_packet_rsp unpacked response payload (fields of version_rsp_bs)
-- @return ncsi_def.NCSI_SUCCESS
local function parse_get_version_id_rsp(pst_packet_rsp)
    local ncsi_para = ncsi_parameter.get_instance():get_ncsi_parameter()

    -- Built locally and assigned once so the shared table never holds a
    -- partially assembled version string.
    ncsi_para.ncsi_ver = format_ncsi_version(pst_packet_rsp)

    ncsi_para.firmware_name = strip_nul_padding(pst_packet_rsp.name_string)

    -- The four raw firmware-version bytes as upper-case hex pairs, e.g. "01:02:0A:FF".
    local fw = pst_packet_rsp.firmware_ver
    ncsi_para.firmware_ver = string.format("%02X:%02X:%02X:%02X",
        string.byte(fw, 1),
        string.byte(fw, 2),
        string.byte(fw, 3),
        string.byte(fw, 4))

    -- NOTE(review): stored exactly as unpacked, without the byte swap applied
    -- to the PCI ids below -- confirm consumers expect this order.
    ncsi_para.manufacture_id = pst_packet_rsp.manufacturer_id

    -- PCI ids are byte-swapped after unpacking.
    local ids = ncsi_para.pcie_device_ids
    ids.pci_did = short_by_big_endian(pst_packet_rsp.pci_did)
    ids.pci_vid = short_by_big_endian(pst_packet_rsp.pci_vid)
    ids.pci_ssid = short_by_big_endian(pst_packet_rsp.pci_ssid)
    ids.pci_svid = short_by_big_endian(pst_packet_rsp.pci_svid)
    return ncsi_def.NCSI_SUCCESS
end

--- Validate a Get Version ID response and, on success, store its contents.
-- Checks, in order:
--   1. the 12-bit payload length from the NC-SI header equals VERSION_RSP_LEN;
--   2. the payload unpacks with version_rsp_bs;
--   3. the checksum matches (a check_sum of 0 skips this check);
--   4. rsp_code equals ncsi_def.CMD_COMPLETED.
-- For any other rsp_code, the response and reason codes are handed to
-- ncsi_utils' common parsers before failing.
-- @param rsp received packet with packet_head and payload fields
-- @return result of parse_get_version_id_rsp on success, ncsi_def.NCSI_FAIL otherwise
local function read_version_rsp(rsp)
    local head = rsp.packet_head
    -- The payload length is 12 bits wide: only the low nibble of payload_len_hi
    -- belongs to it, matching the 0x0f mask write_version_req applies when
    -- building requests. The upper nibble is reserved and must not cause a mismatch.
    local payload_len = ((head.payload_len_hi & 0x0f) << 8) | head.payload_len_lo
    if payload_len ~= VERSION_RSP_LEN then
        log:debug('Unexpected version response payload length %s, expected %s',
            payload_len, VERSION_RSP_LEN)
        return ncsi_def.NCSI_FAIL
    end

    local data = version_rsp_bs:unpack(rsp.payload, true)
    if not data then
        log:error('Failed to unpack version response payload')
        return ncsi_def.NCSI_FAIL
    end

    -- A zero checksum is not verified.
    local check_sum = data.check_sum
    if check_sum ~= 0 then
        local tmp_check_sum = ncsi_utils.get_checksum(rsp, ncsi_def.PACKET_HEAD_LEN + VERSION_RSP_LEN)
        if check_sum ~= core.htonl(tmp_check_sum) then
            log:debug('Version response checksum mismatch')
            return ncsi_def.NCSI_FAIL
        end
    end

    local rsp_code = data.rsp_code
    if rsp_code == ncsi_def.CMD_COMPLETED then
        return parse_get_version_id_rsp(data)
    end

    ncsi_utils.common_respcode_parse(rsp_code)
    ncsi_utils.common_reasoncode_parse(data.reason_code)
    return ncsi_def.NCSI_FAIL
end

-- Handler table passed to ncsi_utils.ncsi_cmd_ctrl by ncsi_get_version_id:
-- the Get Version ID command type maps to the request writer and the response
-- type to the response reader.
-- NOTE(review): presumably ncsi_cmd_ctrl looks handlers up by these packet
-- types -- confirm in ncsi_utils.
local ncsi_version_table = {}
ncsi_version_table[GET_VERSION_ID] = write_version_req
ncsi_version_table[GET_VERSION_ID_RSP] = read_version_rsp

--- Run the NC-SI Get Version ID command on one package/channel.
-- Creates the request packet and hands it, together with ncsi_version_table,
-- to ncsi_utils.ncsi_cmd_ctrl; a failure is logged with the target ids.
-- @param package_id NC-SI package id
-- @param channel_id NC-SI channel id
-- @param eth_name   network interface used for the exchange
-- @return the value returned by ncsi_utils.ncsi_cmd_ctrl (ncsi_def.NCSI_SUCCESS on success)
function ncsi_version.ncsi_get_version_id(package_id, channel_id, eth_name)
    local request = ncsi_packet.create_request_packet(package_id, channel_id, GET_VERSION_ID)
    local result = ncsi_utils.ncsi_cmd_ctrl(package_id, channel_id, request, eth_name, ncsi_version_table)

    if result == ncsi_def.NCSI_SUCCESS then
        return result
    end

    log:error('ncsi cmd ctrl get version failed, package_id = %s, channel_id = %s, eth_name = %s',
        package_id, channel_id, eth_name)
    return result
end

return ncsi_version