-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.


local log = require 'mc.logging'
local class = require 'mc.class'
local crc8 = require 'mc.crc8'
local utils = require 'mc.utils'
local file_sec = require 'utils.file'
local bs = require 'mc.bitstring'
local context = require 'mc.context'
local _, skynet = pcall(require, 'skynet')
local object_pool = require 'object_pool'
-- Command table: every function stored here can be invoked by name via general_hardware_plugins:run_cmd.
local cmds = {}
-- Size in bytes of the fixed part of request_header_bs / response_data_bs below
-- (1+1+2+4+4 and 2+2+4+4, assuming unsized fields such as `lun` default to 8 bits -- TODO confirm).
local HEADER_LEN<const> = 12
-- NOTE(review): not referenced anywhere in this file.
local BUFFER_LEN<const> = 32
local STD_SMBUS_REQ_COMMAND_CODE <const> = 0x20 -- SMBus read/write request command code
local STD_SMBUS_RSP_COMMAND_CODE <const> = 0x21 -- SMBus read/write response command code
-- I2C command codes used by the Anlu CPLD upgrade path (set_spi_mode, set_by_pass, iic_upgrade_anlu).
local I2C_SET_SPI = 0x01
local I2C_UPGRADE_CMD = 0x02

-- Request frame written with STD_SMBUS_REQ_COMMAND_CODE. upgrade_fw sets lun = 0x80 on the
-- last frame of an image and 0 otherwise; change_Wp always uses 0x80.
local request_header_bs <const> = bs.new([[<<
    lun,
    arg,
    opcode:16,
    offset:32,
    length:32,
    data/string >>]])

-- Response frame read back with STD_SMBUS_RSP_COMMAND_CODE; callers inspect `error_code`.
local response_data_bs <const> = bs.new([[<<
    error_code:16,
    opcode:16,
    total_length:32,
    length:32,
    data/string >>]])

-- error_code value meaning the request succeeded.
local success_code <const> = 0
-- Human-readable descriptions of device error codes.
-- NOTE(review): not referenced in this file; the `error_code` parameter of
-- process_error_code shadows this name.
local error_code <const> = {
    [0x1] = 'opcode not support',
    [0x2] = 'parameter error',
    [0x3] = 'internal error',
    [0xF] = 'bus busy'
}
-- Device reports the previous frame was lost; process_error_code rewinds by one frame.
local pre_loss_code <const> = 0x13
-- Device asks for the current frame to be resent; process_error_code waits before the retry
-- (per the original comment, the MCU is erasing its external flash at that point).
local cur_loss_code <const> = 0xff
-- Mask and access type passed with every object_pool 'input' / 'input_args' request in this file.
local DEFAULT_MASK<const> = 0xffffffff
local BLOCK_ACCESS_TYPE<const> = 1

-- Request opcode for each firmware type accepted by cmds.upgrade.
local FIRMWARE_OPCODE = {
    ['DPUCpld'] = 0x1048,
    ['DPUVrd'] = 0x1051
}

--- Compares a received PEC byte against the CRC-8 of `data`.
-- @param check_sum  PEC value taken from the wire (may be nil for an empty response)
-- @param data       bytes the PEC is expected to cover
-- @return true when crc8(data) equals check_sum, false otherwise (mismatch is debug-logged)
local function check_data(check_sum, data)
    local calculated = crc8(data)
    if calculated == check_sum then
        return true
    end
    log:debug('get crc: %s, real crc: %s', check_sum, calculated)
    return false
end

--- Block-writes `data` to `chip` using command/offset `cmd`.
-- The payload is hex-dumped to the debug log only when debug logging is enabled,
-- so the hex conversion is skipped otherwise.
local function chip_write(chip, cmd, data)
    local debug_enabled = log:getLevel() >= log.DEBUG
    if debug_enabled then
        log:debug("write cmd:0x%x data:%s", cmd, utils.to_hex(data))
    end
    local write_args = object_pool.new('input_args', cmd, DEFAULT_MASK, BLOCK_ACCESS_TYPE, #data)
    chip:write(write_args, data)
end

--- Block-reads `len` bytes from `chip` using command/offset `cmd`.
-- @return whatever chip:read returns
local function chip_read(chip, cmd, len)
    log:debug("plugins chip read cmd%s   len%s", cmd, len)
    local read_args = object_pool.new('input_args', cmd, DEFAULT_MASK, BLOCK_ACCESS_TYPE, len, nil)
    return chip:read(read_args)
end

--- Writes `data` through the chip's driver directly, honouring per-access overrides.
-- `access.address` overrides the driver address; `access.requestor` defaults to ''.
-- The pooled 'input' object is recycled after the write.
-- NOTE(review): the length argument is always 4 regardless of #data -- confirm the driver ignores it.
local function chip_write_wp(chip, access, cmd, data)
    local drv = chip.driver
    local request = object_pool.new('input', access.address or drv.address, drv.addr_width,
        drv.offset_width, cmd, 4, DEFAULT_MASK, BLOCK_ACCESS_TYPE, data, access.protocol_flag, chip.name,
        access.is_trace or chip.ismonitored, access.requestor or '')
    chip.driver:write(request)
    object_pool.recycle('input', request)
end

--- Reads `len` bytes through the chip's driver directly, honouring per-access overrides.
-- `access.data_in` is forwarded as the request's data; the pooled 'input' is recycled afterwards.
-- @return the first value returned by driver:read
local function chip_read_wp(chip, access, cmd, len)
    local drv = chip.driver
    local request = object_pool.new('input', access.address or drv.address, drv.addr_width,
        drv.offset_width, cmd, len, DEFAULT_MASK, BLOCK_ACCESS_TYPE, access.data_in, access.protocol_flag,
        chip.name, access.is_trace or chip.ismonitored, access.requestor or '')
    local response = chip.driver:read(request)
    object_pool.recycle('input', request)
    return response
end

--- Write `indata` with `write_cmd`, then read `read_len` bytes with `read_cmd` (via chip:write/read).
-- Errors from either step propagate to the caller.
local function combo_write_read(chip, write_cmd, indata, read_cmd, read_len)
    chip_write(chip, write_cmd, indata)
    -- tail call: returns every value chip_read produces
    return chip_read(chip, read_cmd, read_len)
end

--- Driver-level variant of combo_write_read: write then read through chip.driver with `access` overrides.
-- Errors from either step propagate to the caller.
local function combo_write_read_wp(chip, access, write_cmd, indata, read_cmd, read_len)
    chip_write_wp(chip, access, write_cmd, indata)
    -- tail call: returns the single value chip_read_wp produces
    return chip_read_wp(chip, access, read_cmd, read_len)
end

--- Builds the error text raised whenever a PEC-protected SMBus read fails.
local function read_fail_msg(read_cmd)
    return table.concat({ '[smbus]read commad(0x', string.format('%02x', read_cmd), ') failed' })
end

--- Appends the SMBus PEC byte (CRC-8 over addr, write_cmd and data) to an outgoing payload.
local function append_write_pec(addr, write_cmd, data)
    local crc = crc8(table.concat({ string.char(addr), string.char(write_cmd), data }))
    return data .. string.pack('B', crc)
end

--- Verifies the trailing PEC byte of a read response and returns the response unchanged.
-- The PEC covers addr, read_cmd, the read address (addr | 0x01) and every response byte but the last.
-- Raises read_fail_msg(read_cmd) on mismatch; an empty response yields a nil PEC and also fails.
local function verify_read_pec(addr, read_cmd, value)
    local check_buf = table.concat({ string.char(addr), string.char(read_cmd), string.char(addr | 0x01),
        value:sub(1, #value - 1) })
    if not check_data(value:byte(#value), check_buf) then
        error(read_fail_msg(read_cmd))
    end
    return value
end

--- Writes `data` plus PEC with `write_cmd`, reads `read_len` bytes with `read_cmd` via chip:write/read,
-- and checks the response PEC.
-- @return the raw response (PEC byte still attached)
-- @raise read_fail_msg(read_cmd) on transfer failure or PEC mismatch
local function chip_write_read(chip, addr, write_cmd, data, read_cmd, read_len)
    local ok, value = pcall(combo_write_read, chip, write_cmd, append_write_pec(addr, write_cmd, data), read_cmd,
        read_len)
    if not ok or not value then
        log:error("combo_write_read error: %s", value)
        error(read_fail_msg(read_cmd))
    end
    return verify_read_pec(addr, read_cmd, value)
end

--- Same as chip_write_read, but goes through chip.driver with the `access` overrides.
-- @return the raw response (PEC byte still attached)
-- @raise read_fail_msg(read_cmd) on transfer failure or PEC mismatch
local function chip_write_read_wp(chip, access, addr, write_cmd, data, read_cmd, read_len)
    local ok, value = pcall(combo_write_read_wp, chip, access, write_cmd, append_write_pec(addr, write_cmd, data),
        read_cmd, read_len)
    if not ok or not value then
        log:error("combo_write_read error: %s", value)
        error(read_fail_msg(read_cmd))
    end
    return verify_read_pec(addr, read_cmd, value)
end

--- Block write-then-read: prefixes `data` with its byte length, requests read_len + 2 bytes,
-- and strips the first byte (presumably the block byte count) and the trailing PEC byte.
-- Raises (via chip_write_read) on transfer or PEC failure.
local function chip_blkwrite_read(chip, addr, write_cmd, data, read_cmd, read_len)
    local framed = string.pack('B', #data) .. data
    local raw = chip_write_read(chip, addr, write_cmd, framed, read_cmd, read_len + 2)
    return raw:sub(2, #raw - 1)
end

--- Driver-level variant of chip_blkwrite_read using the `access` overrides.
-- Returns the response without its first byte and trailing PEC byte.
local function chip_blkwrite_read_wp(chip, access, addr, write_cmd, data, read_cmd, read_len)
    local framed = string.pack('B', #data) .. data
    local raw = chip_write_read_wp(chip, access, addr, write_cmd, framed, read_cmd, read_len + 2)
    return raw:sub(2, #raw - 1)
end

--- Packs `head` with request_header_bs, exchanges it over SMBus and unpacks the reply.
-- @param read_len       expected payload length; HEADER_LEN is added for the response header
-- @param slave_address  device address (also used for PEC)
-- @return the unpacked response table, or false if any step raised (the error is logged)
local function send_and_receive(chip, head, read_len, slave_address)
    local function exchange()
        local request = request_header_bs:pack(head)
        local raw = chip_blkwrite_read(chip, slave_address, STD_SMBUS_REQ_COMMAND_CODE, request,
            STD_SMBUS_RSP_COMMAND_CODE, read_len + HEADER_LEN)
        local decoded = response_data_bs:unpack(raw)
        return decoded
    end

    local ok, result = pcall(exchange)
    if ok then
        return result
    end
    log:error("send and receive fail: %s", result)
    return false
end

--- Driver-level variant of send_and_receive using the `access` overrides.
-- @return the unpacked response table, or false if any step raised (the error is logged)
local function send_and_receive_wp(chip, access, head, read_len, slave_address)
    local function exchange()
        local request = request_header_bs:pack(head)
        local raw = chip_blkwrite_read_wp(chip, access, slave_address, STD_SMBUS_REQ_COMMAND_CODE, request,
            STD_SMBUS_RSP_COMMAND_CODE, read_len + HEADER_LEN)
        local decoded = response_data_bs:unpack(raw)
        return decoded
    end

    local ok, result = pcall(exchange)
    if ok then
        return result
    end
    log:error("send and receive fail: %s", result)
    return false
end

--- Decides how upgrade_fw reacts to a non-success device error code for one frame attempt.
-- * cur_loss_code: the device wants the same frame again; sleep 50 before retrying.
-- * pre_loss_code on the first attempt only: rewind `offset` by req.length (and update req.offset)
--   so the previous frame is resent. Multi-frame loss is not handled; a later pre_loss is fatal.
-- * anything else: fatal.
-- Regardless of the above, a failure on attempt 100 is fatal.
-- @return keep_going (boolean), the possibly rewound offset
local function process_error_code(req, offset, error_code, cycle)
    local keep_going = true
    if error_code == cur_loss_code then
        -- the MCU is erasing its external flash; wait before resending the current frame
        skynet.sleep(50)
    elseif error_code == pre_loss_code and cycle == 1 then
        offset = offset - req.length
        req.offset = offset
    else
        log:error("[DPU] start to rerun Upgrade Cpld, error_code: %s", error_code)
        keep_going = false
    end

    local out_of_retries = cycle == 100 and error_code ~= success_code
    if out_of_retries then
        log:error("[DPU] retry up to 100 rerun Upgrade Cpld, error code: %s", error_code)
        keep_going = false
    end

    return keep_going, offset
end

--- Sends a DPU firmware image (CPLD or VRD) to the device frame by frame over SMBus.
-- Each frame carries at most buffer_len - HEADER_LEN bytes starting at `offset`; the last frame
-- of the image is flagged with lun = 0x80. A new frame is started only while the bytes sent by
-- this call do not exceed max_send_len. Every frame is attempted at most 100 times; non-success
-- device codes are handled by process_error_code.
-- @param chip           chip object passed through to send_and_receive
-- @param fw_path        path of the firmware image
-- @param firmware_type  key of FIRMWARE_OPCODE ('DPUCpld' or 'DPUVrd')
-- @param buffer_len     SMBus buffer size; per-frame payload is buffer_len - HEADER_LEN
-- @param slave_address  device address used for the transfer and the PEC
-- @param offset         image byte offset to start (or resume) from
-- @param max_send_len   limit on bytes sent by this call (see above)
-- @param cpld_id        optional 1-based CPLD index, sent as arg = cpld_id - 1 (0 when absent)
-- @return true when the send loop completes, false on any failure
local function upgrade_fw(chip, fw_path, firmware_type, buffer_len, slave_address, offset, max_send_len, cpld_id)
    log:notice("[DPU] CPLD%s start to upgrade %s with plugins offset: %s", cpld_id, firmware_type, offset)

    -- An unknown type gives a nil opcode that can never be packed, so every send would fail.
    local opcode = FIRMWARE_OPCODE[firmware_type]
    if not opcode then
        log:error('[DPU] unsupported firmware type: %s', firmware_type)
        return false
    end

    -- A non-positive payload size would never advance `offset`.
    local max_payload_size = buffer_len - HEADER_LEN
    if max_payload_size <= 0 then
        log:error('[DPU] invalid buffer length: %s', buffer_len)
        return false
    end

    -- cpld_id may be absent for non-CPLD firmware: fall back to 0 instead of computing nil - 1.
    local arg = 0
    if cpld_id and cpld_id >= 1 then
        arg = cpld_id - 1
    end

    local fp = file_sec.open_s(fw_path, 'rb')
    if not fp then
        log:error('[DPU] CPLD%s open firmware file fail', cpld_id)
        return false
    end
    local fw_stream = utils.close(fp, pcall(fp.read, fp, 'a'))
    local fw_stream_len = #fw_stream
    local send_offset = offset
    local req = {
        lun = 0,
        arg = arg,
        opcode = opcode,
    }

    -- Fills req with the frame that starts at frame_offset, so offset, length and data always agree.
    local function load_frame(frame_offset)
        req.offset = frame_offset
        if frame_offset + max_payload_size >= fw_stream_len then
            req.lun = 0x80
            req.length = fw_stream_len - frame_offset
        else
            req.lun = 0
            req.length = max_payload_size
        end
        req.data = fw_stream:sub(frame_offset + 1, frame_offset + req.length)
    end

    while offset < fw_stream_len and max_send_len >= offset - send_offset do
        load_frame(offset)
        local frame_sent = false
        for cycle = 1, 100 do
            local recv_data = send_and_receive(chip, req, max_payload_size, slave_address)
            if recv_data then
                if recv_data.error_code == success_code then
                    offset = offset + req.length
                    frame_sent = true
                    break
                end

                local is_continue, new_offset = process_error_code(req, offset, recv_data.error_code, cycle)
                if not is_continue then
                    return false
                end
                if new_offset ~= offset then
                    -- Rewound to resend an earlier frame: reload its bytes, not the current frame's.
                    if new_offset < 0 then
                        log:error('[DPU] CPLD%s cannot rewind before image start, offset: %s', cpld_id, new_offset)
                        return false
                    end
                    offset = new_offset
                    load_frame(offset)
                end
            end
        end
        -- Without this the outer loop would retry the same frame forever on persistent bus failures.
        if not frame_sent then
            log:error('[DPU] CPLD%s send frame at offset %s failed after 100 retries', cpld_id, offset)
            return false
        end
    end

    return true
end

--- Command entry for DPU firmware upgrade; runs upgrade_fw under pcall.
-- @return true on success; false if upgrade_fw raised or returned a falsy status (logged)
function cmds.upgrade(chip, fw_path, firmware_type, buffer_len, slave_address, offset, max_send_len, cpld_id)
    local ok, status = pcall(upgrade_fw, chip, fw_path, firmware_type, buffer_len, slave_address,
        offset, max_send_len, cpld_id)
    if ok and status then
        return true
    end
    log:error("[DPU] upgrade fail, status: %s", status)
    return false
end

--- Write handler installed as `chip.write` by cmds.rewrite_protect_chip.
-- Takes the first byte of `buffer` as the write-protect value and sends '\x03' .. value in an
-- opcode 0x1055 request to address 0xD4. The request is tried up to 5 times, with
-- skynet.sleep(100) between failed attempts.
-- @param self_chip  the chip whose `write` was replaced
-- @param access     access descriptor forwarded to send_and_receive_wp
-- @param buffer     string whose first byte is the write-protect value
-- Returns nothing; failures are only logged.
local function change_Wp(self_chip, access, buffer)
    if not buffer then
        log:error('Parameter error, buffer is nil')

        return
    end
    -- An empty buffer has no first byte; string.char(nil) below would raise.
    if #buffer == 0 then
        log:error('Parameter error, buffer is empty')

        return
    end
    local wp = buffer:byte(1)
    local data = '\x03' .. string.char(wp)
    local req = {
        lun = 0x80,
        arg = 0x00,
        opcode = 0x1055,
        offset = 0x00,
        length = #data,
        data = data
    }
    local retry_times = 5
    for attempt = 1, retry_times do
        local recv_data = send_and_receive_wp(self_chip, access, req, 20, 0xD4)
        if recv_data and recv_data.error_code == success_code then
            log:notice('[DPU]change write protect to %s successfully', wp)

            return
        end
        log:error('[DPU] change write protect failed, value :%s', wp)
        -- Do not block the caller with a sleep after the final attempt.
        if attempt < retry_times then
            skynet.sleep(100)
        end
    end
end

--- Turns `chip` into the write-protect chip: restores its real address and routes writes to change_Wp.
function cmds.rewrite_protect_chip(chip)
    -- The framework merges chips that share an address, so the CSR config swaps `address` and
    -- `read_retry_times`; swap them back here so the chip objects are not merged.
    local configured_address = chip.driver.address
    chip.driver.address = chip.read_retry_times
    chip.read_retry_times = configured_address

    chip.write = change_Wp
    log:notice('[DPU]init write protect chip successfully')
end

--- Block-writes `send_data` to a retimer, optionally appending the SMBus PEC.
-- The PEC is CRC-8 over the driver address, write_cmd and send_data.
-- @return the results of pcall(chip_write, ...)
function cmds.retimer_smbus_write_block(chip, use_pec, send_data, write_cmd)
    local payload = send_data
    if use_pec then
        local pec = crc8(string.char(chip.driver.address) .. string.char(write_cmd) .. send_data)
        payload = payload .. string.char(pec)
    end
    return pcall(chip_write, chip, write_cmd, payload)
end

--- Writes `send_data` (optionally with PEC) then reads `read_length` bytes from a retimer.
-- With use_pec, the last response byte is checked as the PEC over driver address, read_cmd,
-- driver address + 1 and the remaining bytes, and is stripped from the result.
-- @return true, data on success; false, nil on write failure, read failure or PEC mismatch
function cmds.retimer_smbus_read_block(chip, use_pec, send_data, write_cmd, read_cmd, read_length)
    if use_pec then
        local pec = crc8(string.char(chip.driver.address) .. string.char(write_cmd) .. send_data)
        send_data = send_data .. string.char(pec)
    end

    local wrote = pcall(chip_write, chip, write_cmd, send_data)
    if not wrote then
        log:error("retimer_smbus_read_block chip_write fail, send_data: %s", utils.to_hex(send_data))
        return false, nil
    end

    local read_ok, reply = pcall(chip_read, chip, read_cmd, read_length)
    if not read_ok then
        log:error('retimer_smbus_read_block chip_read fail, error: %s', reply)
        return false, nil
    end

    if not use_pec then
        return true, reply
    end

    local check_buf = string.char(chip.driver.address) .. string.char(read_cmd) ..
        string.char(chip.driver.address + 1) .. reply:sub(1, #reply - 1)
    local pec_char = reply:sub(#reply)
    if not check_data(string.byte(pec_char), check_buf) then
        log:error('PEC check fail %s, receive pec is %s, calculate crc is %s',
            utils.to_hex(check_buf),
            utils.to_hex(pec_char),
            crc8(check_buf)
        )
        return false, nil
    end
    return true, reply:sub(1, #reply - 1)
end

--- Calls chip_write up to `retries` times, sleeping `wait` after each failed attempt.
-- @return the pcall status and result of the last attempt (true on first success)
local function retry_chip_write(wait, retries, chip, cmd, data)
    local succeeded, outcome = false, nil
    for _ = 1, retries do
        succeeded, outcome = pcall(chip_write, chip, cmd, data)
        if succeeded then
            return succeeded, outcome
        end
        skynet.sleep(wait)
    end
    return succeeded, outcome
end

--- Writes 0xF0 to I2C_SET_SPI (10 attempts, sleep 10 between them) to select SPI mode.
-- @return the results of retry_chip_write
local function set_spi_mode(chip)
    local SPI_MODE_PAYLOAD <const> = '\xf0'
    return retry_chip_write(10, 10, chip, I2C_SET_SPI, SPI_MODE_PAYLOAD)
end

--- Writes 0x06 0x00 to I2C_UPGRADE_CMD (10 attempts, sleep 10 between them).
-- Used as the write enable before each flash write in the Anlu upgrade path.
-- @return the results of retry_chip_write
local function set_by_pass(chip)
    local BYPASS_PAYLOAD <const> = '\x06\x00'
    return retry_chip_write(10, 10, chip, I2C_UPGRADE_CMD, BYPASS_PAYLOAD)
end

--- Like retry_chip_write, but after each failure logs the error and re-issues set_by_pass.
-- Subsequent retries would otherwise fail because the write enable must be set again after a failed write.
-- @return the pcall status and result of the last attempt (true on first success)
local function retry_chip_write_by_pass(wait, retries, chip, cmd, data)
    local succeeded, outcome = false, nil
    for _ = 1, retries do
        succeeded, outcome = pcall(chip_write, chip, cmd, data)
        if succeeded then
            return succeeded, outcome
        end
        skynet.sleep(wait)
        log:error('[CPLD]write data faild, error: %s', outcome)
        set_by_pass(chip)
    end
    return succeeded, outcome
end

-- Data field layout for I2C file transfer: 8-bit flag, 24-bit big-endian offset, then the payload.
-- iic_upgrade_anlu packs each chunk with flag = 2 and the chunk's byte offset in the image.
local smc_transfer_file_data = bs.new([[<<
    flag:8,
    offset:24/big,
    data/string
>>]])

--- Programs an Anlu CPLD image over I2C in fixed-size chunks.
-- Selects SPI mode first. Then, for each `written_len`-byte chunk of the first `max_write_len`
-- bytes of the file, it does the following:
--   1. Re-issues set_by_pass.
--   2. Writes the chunk, framed by smc_transfer_file_data plus a trailing 0x00, with retries.
--   3. Calls skynet.sleep(delay).
-- @param chip           target chip object
-- @param upg_file_path  path of the upgrade image
-- @param max_write_len  number of bytes to program; the file must be at least this long
-- @param delay          value passed to skynet.sleep after every chunk
-- @param written_len    chunk size in bytes; must be a positive number
-- @return true on success, false on any failure
function cmds.iic_upgrade_anlu(chip, upg_file_path, max_write_len, delay, written_len)
    log:notice('Upgrade start, max_write_len is %s, delay is %s, written_len is %s',
        max_write_len, delay, written_len)
    -- A missing or non-positive chunk size would never advance `offset` and loop forever.
    local chunk_len = tonumber(written_len)
    if not chunk_len or chunk_len <= 0 then
        log:error('[CPLD]invalid written_len: %s', written_len)
        return false
    end
    -- Named `open_err` so the global error() function is not shadowed.
    local file, open_err = file_sec.open_s(upg_file_path, 'rb')
    if not file then
        log:error('upg_file not exist, error:%s', open_err)
        return false
    end
    local data = utils.close(file, pcall(file.read, file, '*a'))
    local file_len = #data
    -- Re-select SPI mode before writing (its result is not checked).
    set_spi_mode(chip)

    -- The image is written over 0 .. max_write_len - 1 (e.g. 0~0x5ffff); a shorter file is faulty.
    if file_len < max_write_len then
        log:error('[CPLD]data is too small to write flash, length:%s', file_len)
        return false
    end

    local offset = 0
    while offset < max_write_len do
        local chunk = string.sub(data, offset + 1, offset + chunk_len)
        local frame = smc_transfer_file_data:pack({
            flag = 2,
            offset = offset,
            data = chunk
        }) .. '\x00'
        set_by_pass(chip)
        local ok, err = retry_chip_write_by_pass(3, 3, chip, I2C_UPGRADE_CMD, frame)
        skynet.sleep(delay)
        if not ok then
            log:error('[CPLD]write flash(%s) failed, error: %s', offset, err)
            return false
        end
        offset = offset + chunk_len
    end
    return true
end

local MAX_RETRY = 10
local PMBUS_READ_HEAD_LEN = 3

--- Reads one PMBus value with PEC verification, retrying up to MAX_RETRY times.
-- The PEC is CRC-8 over addr, read_cmd, addr | 0x01 and the first data_len payload bytes,
-- and is expected at payload byte data_len + 1. An attempt counts only if the read succeeds
-- and returns at least read_total_len bytes.
-- @return the first data_len payload bytes, or nil if no attempt produced a valid frame
local function pmbus_read_verified(chip, req)
    local read_head = string.char(req.addr, req.read_cmd, req.addr | 0x01)
    for _ = 1, MAX_RETRY do
        local read_ok, payload = pcall(chip_read, chip, req.read_cmd, req.read_total_len)
        if read_ok and payload and #payload >= req.read_total_len then
            local stream = read_head .. payload
            local covered_len = req.data_len + PMBUS_READ_HEAD_LEN
            local computed_crc = crc8(string.sub(stream, 1, covered_len))
            if string.byte(stream, covered_len + 1) == computed_crc then
                return string.sub(payload, 1, req.data_len)
            end
        end
    end
    return nil
end

--- Executes a batch of PMBus reads. Each entry writes `write_data` plus PEC with `write_cmd`,
-- then reads and PEC-verifies the value with `read_cmd`.
-- Each request table provides addr, write_cmd, write_data, read_cmd, read_total_len and data_len.
-- @return true, list of values in request order; or false, message naming the failing index
function cmds.pmbus_batch_read(chip, batch_reads)
    local results = {}

    for index, req in ipairs(batch_reads) do
        local write_head = string.char(req.addr, req.write_cmd)
        local framed = req.write_data .. string.char(crc8(write_head .. req.write_data))

        local wrote = pcall(chip_write, chip, req.write_cmd, framed)
        if not wrote then
            log:error('[pmbus_batch_read] Write cmd:0x%02X failed at batch index %d', req.write_cmd, index)
            return false, string.format('write failed at index %d', index)
        end

        local value = pmbus_read_verified(chip, req)
        if not value then
            log:error('[pmbus_batch_read] Read cmd:0x%02X failed at batch index %d after %d retries',
                     req.read_cmd, index, MAX_RETRY)
            return false, string.format('read failed at index %d', index)
        end

        results[index] = value
    end

    return true, results
end

--- Plugin object exposing the functions stored in `cmds` by name.
local general_hardware_plugins = class()

--- Constructor; only logs that the plugin was created.
function general_hardware_plugins:ctor()
    log:notice('[general_hardware_plugins] ctor')
end

--- Reports whether a command named `cmd_name` is registered.
-- @return true if cmds[cmd_name] is non-nil
function general_hardware_plugins:has_cmd(cmd_name)
    local handler = cmds[cmd_name]
    return handler ~= nil
end

--- Runs command `cmd` against `chip`, forwarding any extra arguments and all return values.
-- An unregistered `cmd` raises (call of nil); use has_cmd first.
function general_hardware_plugins:run_cmd(chip, cmd, ...)
    log:debug('[general_hardware_plugins] run cmd[%s]', cmd)
    return cmds[cmd](chip, ...)
end

return general_hardware_plugins