-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.


local log = require 'mc.logging'
local class = require 'mc.class'
local crc8 = require 'mc.crc8'
local utils = require 'mc.utils'
local file_sec = require 'utils.file'
local bs = require 'mc.bitstring'
local context = require 'mc.context'
local _, skynet = pcall(require, 'skynet')
local object_pool = require 'object_pool'
local cmds = {}
local HEADER_LEN<const> = 12
local BUFFER_LEN<const> = 32
local STD_SMBUS_REQ_COMMAND_CODE <const> = 0x20 -- smbus读/写请求命令码
local STD_SMBUS_RSP_COMMAND_CODE <const> = 0x21 -- smbus读/写响应命令码

local request_header_bs <const> = bs.new([[<<
    lun,
    arg,
    opcode:16,
    offset:32,
    length:32,
    data/string >>]])

local response_data_bs <const> = bs.new([[<<
    error_code:16,
    opcode:16,
    total_length:32,
    length:32,
    data/string >>]])

local success_code <const> = 0
local error_code <const> = {
    [0x1] = 'opcode not support',
    [0x2] = 'parameter error',
    [0x3] = 'internal error',
    [0xF] = 'bus busy'
}
local pre_loss_code <const> = 0x13
local cur_loss_code <const> = 0xff
local DEFAULT_MASK<const> = 0xffffffff
local BLOCK_ACCESS_TYPE<const> = 1

local FIRMWARE_OPCODE = {
    ['DPUCpld'] = 0x1048,
    ['DPUVrd'] = 0x1051
}

local function check_data(check_sum, data)
    local crc = crc8(data)
    if crc ~= check_sum then
        log:debug('get crc: %s, real crc: %s', check_sum, crc)
        return false
    end
    return true
end

local function chip_write(chip, cmd, data)
    if log:getLevel() >= log.DEBUG then
        log:debug("write cmd:0x%x data:%s", cmd, utils.to_hex(data))
    end
    local input = object_pool.new('input_args', cmd, DEFAULT_MASK, BLOCK_ACCESS_TYPE, #data)
    chip:write(input, data)
end

local function chip_read(chip, cmd, len)
    log:debug("plugins chip read cmd%s   len%s", cmd, len)
    local input = object_pool.new('input_args', cmd, DEFAULT_MASK, BLOCK_ACCESS_TYPE, len, nil)
    return chip:read(input)
end

local function combo_write_read(chip, write_cmd, indata, read_cmd, read_len)
    chip_write(chip, write_cmd, indata)
    return chip_read(chip, read_cmd, read_len)
end

local function chip_write_read(chip, addr, write_cmd, data, read_cmd, read_len)
    -- write read data
    local check_buf = table.concat({ string.char(addr), string.char(write_cmd), data })
    local crc = crc8(check_buf)

    local ok, value = pcall(combo_write_read, chip, write_cmd, data .. string.pack('B', crc), read_cmd, read_len)
    if not ok or not value then
        log:error("combo_write_read error: %s", value)
        error(table.concat({ '[smbus]read commad(0x', string.format('%02x', read_cmd), ') failed' }))
    end

    -- check_data
    check_buf = table.concat({ string.char(addr), string.char(read_cmd), string.char(addr | 0x01),
        value:sub(1, #value - 1) })
    if not value or not check_data(value:sub(#value, #value):byte(), check_buf) then
        error(table.concat({ '[smbus]read commad(0x', string.format('%02x', read_cmd), ') failed' }))
    end
    return value
end

local function chip_blkwrite_read(chip, addr, write_cmd, data, read_cmd, read_len)
    local value = chip_write_read(chip, addr, write_cmd, string.pack('B', #data) .. data, read_cmd, read_len + 2)
    return value:sub(2, #value - 1)
end

local function send_and_receive(chip, head, read_len, slave_address)
    local ok, recv_data = pcall(function()
        local recv_data = response_data_bs:unpack(chip_blkwrite_read(chip, slave_address,
            STD_SMBUS_REQ_COMMAND_CODE, request_header_bs:pack(head), STD_SMBUS_RSP_COMMAND_CODE,
             read_len + HEADER_LEN))
        return recv_data
    end)
    if not ok then
        log:error("send and receive fail: %s", recv_data)
        return false
    end
    return recv_data
end

local function process_error_code(req, offset, error_code, cycle)
    local is_continue = true
    if error_code == pre_loss_code and cycle == 1 then -- 前一帧丢失，重新发送前一帧，cpld升级不考虑多帧丢失，一帧最多重试100次
        offset = offset - req.length
        req.offset = offset
    elseif error_code == cur_loss_code then
        skynet.sleep(50) -- 当前帧重发代表MCU擦除外挂flash中，秒级等待
    else
        log:error("[DPU] start to rerun Upgrade Cpld, error_code: %s", error_code)
        is_continue = false
    end

    if cycle == 100 and error_code ~= success_code then
        log:error("[DPU] retry up to 100 rerun Upgrade Cpld, error code: %s", error_code)
        is_continue = false
    end

    return is_continue, offset
end

local function upgrade_fw(chip, fw_path, firmware_type, buffer_len, slave_address, offset, max_send_len)
    log:notice("[DPU] start to upgrade %s with plugins offset: %s", firmware_type, offset)

    local fp = file_sec.open_s(fw_path, 'rb')
    if not fp then
        log:error('[DPU] open firmware file fail')
        return false
    end
    local fw_stream = utils.close(fp, pcall(fp.read, fp, 'a'))
    local fw_stream_len = #fw_stream
    local max_payload_size = buffer_len - HEADER_LEN
    local send_offset = offset
    local is_continue
    local req = {
        lun = 0,
        arg = 0,
        opcode = FIRMWARE_OPCODE[firmware_type],
    }

    local recv_data
    while offset < fw_stream_len and max_send_len >= offset - send_offset do
        req.offset = offset
        if offset + max_payload_size >= fw_stream_len then
            req.lun = 0x80
            req.length = fw_stream_len - offset
        else
            req.lun = 0
            req.length = max_payload_size
        end
        req.data = fw_stream:sub(offset + 1, offset + req.length)
        for cycle = 1, 100 do
            recv_data = send_and_receive(chip, req, buffer_len - HEADER_LEN, slave_address)
            if not recv_data then
                goto continue
            end
            
            if recv_data.error_code == success_code then
                offset = offset + req.length
                break
            end

            is_continue, offset = process_error_code(req, offset, recv_data.error_code, cycle)
            if not is_continue then
                return false
            end

            ::continue::
        end
    end

    return true
end

function cmds.upgrade(chip, fw_path, firmware_type, buffer_len, slave_address, offset, max_send_len)
    local ok, status = pcall(upgrade_fw, chip, fw_path, firmware_type, buffer_len, slave_address, offset, max_send_len)
    if not ok or not status then
        log:error("[DPU] upgrade fail, status: %s", status)
        return false
    end
    return true
end

local function change_Wp(self_chip, obj)
    if not obj.value then
        obj.fetch.result = false
        obj.fetch.cb()
        return
    end
    local wp = obj.value:sub(1,1):byte()
    local data =  '\x03' .. string.char(wp)
    local req = {
        lun = 0x80,
        arg = 0x00,
        opcode = 0x1055,
        offset = 0x00,
        length =  #data,
        data = data
    }
    local retry_times = 5
    local recv_data
    for _ = 1, retry_times do
        recv_data = send_and_receive(self_chip, req, 20, 0xD4)
        if not recv_data or recv_data.error_code ~= success_code then
            log:error('[DPU] change write protect failed, value :%s', wp)
            obj.fetch.result = false
            skynet.sleep(100)
        else
            log:notice('[DPU]change write protect to %s successfully', wp)
            obj.fetch.result = true
            break
        end
    end
    obj.fetch.cb()
end

function cmds.rewrite_protect_chip(chip)
    -- 框架会合并相同address的chip,配置CSR时将address和read_retry_times倒换,代码中再swap,防止chip对象合并
    local tmp = chip.driver.address
    chip.driver.address = chip.read_retry_times
    chip.read_retry_times = tmp
    chip.write_once = change_Wp
    log:notice('[DPU]init write protect chip successfully')
end

local general_hardware_plugins = class()

function general_hardware_plugins:ctor()
    log:notice('[general_hardware_plugins] ctor')
end

function general_hardware_plugins:has_cmd(cmd_name)
    return cmds[cmd_name] ~= nil
end

function general_hardware_plugins:run_cmd(chip, cmd, ...)
    log:debug('[general_hardware_plugins] run cmd[%s]', cmd)
    return cmds[cmd](chip, ...)
end

return general_hardware_plugins