-- Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
--
-- this file licensed under the Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
-- You may obtain a copy of Mulan PSL v2 at: http://license.coscl.org.cn/MulanPSL2
--
-- THIS SOFTWARE IS PROVIDED ON AN \"AS IS\" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
-- IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
-- PURPOSE.
-- See the Mulan PSL v2 for more details.
--

local math = require 'math'
local class = require 'mc.class'
local protocol = require 'protocol_open.protocol.protocol'
local bs = require 'mc.bitstring'
local ctx = require 'mc.context'
local skynet = require 'skynet'
local skynet_queue = require 'skynet.queue'
local log = require 'mc.logging'
local utils = require 'mc.utils'
local crc8 = require 'mc.crc8'

-- Pause inserted before each write/read transaction, in skynet.sleep units
-- (one unit is 10 ms; this file also uses skynet.sleep(100) for a 1 second pause).
local write_read_delay <const> = 1 -- 10 ms
-- Size in bytes of the fixed frame header. It matches the response header
-- (16 + 16 + 32 + 32 bits); for the request header it assumes lun/arg are 8 bits each (TODO confirm).
local header_len <const> = 12
-- Number of consecutive retries write_and_read allows after a failed transaction.
local max_retry_count <const> = 3

-- MCU SMBus address used when the constructor params do not provide mcu_address.
local default_mcu_address <const> = 0xD4
-- SMBus command code used when sending a request frame (write).
local write_command_code <const> = 0x20
-- SMBus command code used when reading back a response frame.
local read_command_code <const> = 0x21

-- Layout of a response body: 16-bit error code, 16-bit opcode (compared against the
-- request opcode by _unpack_request), 32-bit total payload length across all frames,
-- 32-bit payload length of this frame, then that many payload bytes.
local response_data_bs <const> = bs.new([[<<
    error_code:16,
    opcode:16,
    total_length:32,
    length:32,
    data:length/string
>>]])

-- Outer framing shared by requests and responses: an 8-bit byte count followed by a
-- single body whose layout is bound when the template is instantiated in ctor
-- (request_header_bs for requests, response_data_bs for responses).
local frame_template <const> = [[<<
    count:8,
    body:1/body
>>]]

-- Request header; construct_request_data appends the write payload after it.
-- lun carries the last-frame flag (0x80) OR'd with data_object_index; offset and length
-- describe this frame. lun and arg have no explicit width here (presumably the bitstring
-- default of 8 bits each, giving the 12-byte header_len — TODO confirm).
local request_header_bs <const> = bs.new([[<<
    lun,
    arg,
    opcode:16,
    offset:32,
    length:32
>>]])

-- Whitelist of request keys accepted by validate_request_params.
-- NOTE(review): _send_request_internal also reads max_frame_len, align_len and
-- write_without_read, which are not listed here; confirm whether validation should accept them.
local request_params_template <const> = {
    opcode = true,
    offset = true,
    expect_data_len = true,
    arg = true,
    data = true,
    batch_write = true,
    data_object_index = true
}

-- std_smbus is built on the generic protocol base class.
local std_smbus = class(protocol)

-- Response error_code value that means the MCU handled the request successfully.
local success_code <const> = 0
-- Descriptions of non-zero response error codes, used to build error messages.
local error_code <const> = {
    [0x1] = 'opcode not support',
    [0x2] = 'parameter error',
    [0x3] = 'internal error',
    [0xF] = 'bus busy'
}

--- Protocol initialisation hook; std_smbus has nothing to set up.
-- @treturn boolean always true
function std_smbus:init()
    local initialised = true
    return initialised
end

--- Largest payload that fits in one frame.
-- Computed as the configured buffer length minus the fixed frame header.
-- @treturn number payload capacity in bytes
function std_smbus:get_max_frame_len()
    local payload_capacity = self.buffer_len - header_len
    return payload_capacity
end

--- Delay inserted before each write/read transaction.
-- Uses the per-instance write_read_delay stored by the constructor, so an instance can
-- tune its own pacing. Falls back to the module default when the field is unset.
-- @treturn number delay in skynet.sleep units (1 unit = 10 ms)
function std_smbus:get_delay_time()
    return self.write_read_delay or write_read_delay
end

-- 只检查是否含有request_params_template以外的参数
--- Check that a request only carries keys listed in request_params_template.
-- Only unknown keys are rejected; the values themselves are not validated.
-- @tparam table req request parameter table
-- @treturn boolean true when every key of req is whitelisted
function std_smbus:validate_request_params(req)
    local unknown_found = false
    for field in pairs(req) do
        if not self.request_params_template[field] then
            unknown_found = true
            break
        end
    end
    return not unknown_found
end

--- Build the binary request for the current frame.
-- The frame is the packed header (byte count + request_header_bs) followed by the
-- slice of track_request.data that starts at data_written and is at most len bytes long.
-- @tparam table track_request per-request state (data, data_written, len, opcode, arg,
--   offset_invalid, data_object_index)
-- @treturn string packed header concatenated with the payload slice
function std_smbus:construct_request_data(track_request)
    local written = track_request.data_written
    local frame_len = track_request.len
    local payload = track_request.data:sub(written + 1, written + frame_len)

    -- Bit 7 of lun marks the frame that completes the write.
    local last_frame_flag = 0
    if written + frame_len >= #track_request.data then
        last_frame_flag = 0x80
    end

    -- When the caller asked for no offset (offset == -1), always send 0.
    local offset = written
    if track_request.offset_invalid then
        offset = 0
    end

    local header = self.request_bs:pack({
        count = header_len + #payload,
        body = {
            lun = last_frame_flag | (track_request.data_object_index or 0),
            arg = track_request.arg,
            opcode = track_request.opcode,
            offset = offset,
            length = frame_len
        }
    })
    return header .. payload
end

--- Parse a raw response frame and verify its trailing PEC byte.
-- The PEC is a CRC-8 over the MCU address, the read command code, the MCU address
-- with bit 0 set, and every response byte except the last. The last byte is the received PEC.
-- @tparam string rsp_bin raw bytes read from the chip
-- @param _ unused (expected data length)
-- @treturn boolean ok
-- @return parsed response body on success, or an error message on failure
function std_smbus:unpack_response_data(rsp_bin, _)
    local frame = self.response_bs:unpack(rsp_bin, true)
    if not frame then
        return false, 'unable to parse smbus response binary'
    end

    local received_pec = rsp_bin:byte(-1)
    local pec_prefix = string.format('%c%c%c',
        self.mcu_address, read_command_code, self.mcu_address | 0x01)
    local computed_pec = crc8(pec_prefix .. rsp_bin:sub(1, -2))
    if received_pec ~= computed_pec then
        return false, 'crc check error on std smbus response'
    end
    return true, frame.body
end

--- Write raw bytes to the chip at the given offset.
-- Errors raised by the chip call are caught.
-- @return results of pcall: true plus the Write results, or false plus the error
function std_smbus:send(offset, data)
    log:debug('sending offset: %s, data: %s', offset, utils.to_hex(data))
    local function do_write()
        return self.ref_chip:Write(ctx.new(), offset, data)
    end
    return pcall(do_write)
end

--- Read len bytes from the chip at the given offset.
-- Errors raised by the chip call are caught.
-- @return results of pcall: true plus the Read results, or false plus the error
function std_smbus:receive(offset, len)
    local function do_read()
        return self.ref_chip:Read(ctx.new(), offset, len)
    end
    return pcall(do_read)
end

--- Send one request frame with its PEC byte and read back len bytes in one combined transfer.
-- The PEC is a CRC-8 over the MCU address, the write command code and the request bytes.
-- @tparam string data request frame without PEC
-- @tparam number len number of bytes to read back
-- @return results of pcall: true plus the ComboWriteRead results, or false plus the error
function std_smbus:send_and_receive(data, len)
    local crc = crc8(string.char(self.mcu_address, write_command_code) .. data)
    log:debug('sending std smbus write command, data: %s, crc: %s', utils.to_hex(data), crc)
    local function transfer()
        return self.ref_chip:ComboWriteRead(ctx.new(), write_command_code,
            data .. string.char(crc), read_command_code, len)
    end
    return pcall(transfer)
end

--- Validate a raw response frame and check that it answers the given request.
-- @tparam string data raw bytes returned by the chip transfer
-- @tparam table track_request per-request state; its opcode must match the response
-- @treturn boolean ok
-- @return parsed response body on success, or an error message string on failure
function std_smbus:_unpack_request(data, track_request)
    -- The chip call may succeed without returning bytes. Report it as a failure so the
    -- caller can retry, instead of raising from the string operations below.
    if type(data) ~= 'string' then
        return false, string.format('Request(%s) got no response data from the chip',
            track_request.opcode)
    end
    log:debug('receiving data: %s', utils.to_hex(data))
    -- A leading 0xFF byte is rejected as invalid
    -- (NOTE(review): presumably an idle bus reading back all ones — confirm).
    if data:sub(1, 1) == '\xFF' then
        return false, string.format('Request(%s) invalid response from the chip: %s',
            track_request.opcode, utils.to_hex(data))
    end

    local ok, result = self:unpack_response_data(data, track_request.expect_data_len)
    if not ok then
        return ok, result
    end

    if result.error_code ~= success_code then
        return false, 'receive error from the chip, error: ' ..
            (self.error_code[result.error_code] or 'undefined error code')
    end
    if result.opcode ~= track_request.opcode then
        -- string.format('%s') tolerates a nil opcode, where '..' would raise.
        return false, string.format(
            'received incorrect result from the chip! expected opcode: %s, actual: %s',
            track_request.opcode, result.opcode)
    end
    return ok, result
end

--- Shrink track_request.len to the remaining byte count when the next frame is the last one.
-- It applies only once some data has been written, and only when the data is non-empty
-- and the remaining bytes fit within max_frame_len. Otherwise len is left unchanged.
-- @tparam table track_request per-request state, updated in place
function std_smbus:calc_last_frame_len(track_request)
    local written = track_request.data_written
    if written <= 0 then
        return
    end
    local total = #track_request.data
    if total <= 0 or written + track_request.max_frame_len < total then
        return
    end
    track_request.len = total - written
end

--- Build the current frame and exchange it with the chip.
-- The last frame carries only the remaining bytes; earlier frames use the full frame length.
-- @tparam table track_request per-request state
-- @return results of send_and_receive, or false plus an error message
function std_smbus:send_and_receive_request_in_frames(track_request)
    self:calc_last_frame_len(track_request)

    local frame = self:construct_request_data(track_request)
    if frame == '' then
        return false, 'unable to construct request data'
    end
    -- The read size adds 2 bytes to payload + header: the leading byte count and the trailing PEC.
    local read_len = header_len + track_request.len + 2
    return self:send_and_receive(frame, read_len)
end

--- Merge one response frame into the accumulated response.
-- The first frame (final_response.data still empty) supplies error_code, opcode and
-- total_length. Payload is appended until the accumulated length reaches total_length.
-- Bytes past that point in the last frame are dropped. The total from the first frame
-- is used for both the completion check and the trimming, so data and length stay consistent.
-- @tparam table final_response accumulator (see get_response_empty_obj), updated in place
-- @tparam table rsp_raw one parsed response body
-- @treturn boolean true once the full payload has been collected
function std_smbus:_concat_response(final_response, rsp_raw)
    if final_response.data == '' then
        final_response.error_code = rsp_raw.error_code
        final_response.opcode = rsp_raw.opcode
        final_response.total_length = rsp_raw.total_length
    end

    local total = final_response.total_length
    if final_response.length + rsp_raw.length < total then
        -- Intermediate frame: keep all of its payload.
        final_response.data = final_response.data .. rsp_raw.data
        final_response.length = final_response.length + rsp_raw.length
        return false
    end

    -- Last frame: keep only the bytes still missing from the announced total.
    local remaining = total - final_response.length
    final_response.data = final_response.data .. rsp_raw.data:sub(1, remaining)
    final_response.length = total
    return true
end

--- Create a zeroed response accumulator by parsing a buffer of zero bytes.
-- @return response table as produced by response_data_bs:unpack
function std_smbus:get_response_empty_obj()
    local zeroed = ('\x00'):rep(self.buffer_len)
    return response_data_bs:unpack(zeroed, true)
end

--- Perform one paced write/read transaction for the current frame.
-- Sleeps get_delay_time() first so consecutive transactions are not back-to-back.
-- @tparam table track_request per-request state
-- @treturn boolean ok
-- @return parsed response body on success, or an error message on failure
function std_smbus:write_read(track_request)
    skynet.sleep(self:get_delay_time())
    local sent, reply = self:send_and_receive_request_in_frames(track_request)
    if sent then
        return self:_unpack_request(reply, track_request)
    end
    return sent, reply
end

--- Write track_request.data to the chip in frames through the chip's BatchWrite API.
-- Frames, each with its PEC byte, are queued. The queue is flushed whenever the written
-- byte count crosses a multiple of 2048 bytes, followed by a 100-unit skynet.sleep (1 s)
-- to throttle the transfer. Frames left over when the data is exhausted are flushed at the end.
-- @tparam table track_request per-request state; data_written (and possibly len) are
--   updated in place
-- @treturn boolean|nil true on success, nil on failure
-- @treturn[2] string error description on failure
function std_smbus:batch_write(track_request)
    local max_step_size <const> = 2048
    local data_batch = {}

    -- Send every queued frame in a single BatchWrite call.
    -- Returns true, or false plus the error.
    local function flush()
        local ok, err = pcall(function()
            return self.ref_chip:BatchWrite(ctx.new(), data_batch)
        end)
        if not ok then
            log:error('batch write failed, %s', err)
            return false, err
        end
        return true
    end

    repeat
        -- The last frame carries only the remaining bytes; others use the full frame length.
        self:calc_last_frame_len(track_request)

        local req_bin = self:construct_request_data(track_request)
        -- Append the PEC (CRC-8 over address, write command code and frame).
        local crc = crc8(string.char(self.mcu_address) .. string.char(write_command_code) .. req_bin)
        data_batch[#data_batch + 1] = { write_command_code, req_bin .. string.char(crc) }

        local written_before = track_request.data_written
        local frame_len = track_request.len
        if (written_before + frame_len) // max_step_size > written_before // max_step_size then
            local ok, err = flush()
            if not ok then
                return nil, err
            end
            skynet.sleep(100)
            data_batch = {}
        end

        track_request.data_written = written_before + frame_len
        local is_completed = track_request.data_written >= #track_request.data
        -- A non-positive frame length (e.g. buffer_len <= header_len) would never advance
        -- the offset. The loop would then spin forever without yielding.
        if not is_completed and frame_len <= 0 then
            local msg = string.format(
                'batch write aborted: frame length %s does not advance the write offset', frame_len)
            log:error('%s', msg)
            return nil, msg
        end
    until is_completed

    -- Flush the frames that did not reach a 2048-byte boundary.
    if next(data_batch) ~= nil then
        local ok, err = flush()
        if not ok then
            return nil, err
        end
    end
    return true
end

-- std_smbus currently has no write-without-read use case; if one is needed, the
-- implementation can be moved here from smbus_5902.
-- Intentionally a no-op stub: it returns no values, so send_request yields nothing
-- for requests with write_without_read set.
function std_smbus:write_without_read(track_request)
end

--- Pad track_request.data with zero bytes up to a multiple of request.align_len.
-- Does nothing when align_len is absent, non-numeric, or not positive. A zero
-- align_len previously raised a modulo-by-zero error.
-- @tparam table track_request per-request state; data is replaced when padding is needed
-- @tparam table request original request; align_len is read from it
function std_smbus:append_data(track_request, request)
    -- tonumber keeps numeric strings (accepted before through arithmetic coercion) working.
    local align = tonumber(request.align_len)
    if not align or align <= 0 then
        return
    end

    local remainder = #track_request.data % align
    if remainder ~= 0 then
        track_request.data = track_request.data .. string.rep('\x00', align - remainder)
    end
end

--- Exchange frames with the chip until all data is written and the full response is read.
-- Each failed transaction is retried. After more than max_retry_count consecutive
-- failures the request is abandoned. The failure counter resets after every success.
-- @tparam table track_request per-request state, updated in place
-- @treturn string|nil accumulated response payload, or nil when retries are exhausted
function std_smbus:write_and_read(track_request)
    local failures = 0
    local response = self:get_response_empty_obj()
    local write_done, read_done = false, false

    while not (write_done and read_done) do
        local ok, rsp_raw = self:write_read(track_request)
        if not ok then
            if failures >= max_retry_count then
                log:debug(
                    'failed to retrieve response protocol: %s, msg: %s. Reach max retry, exit',
                    self.name, rsp_raw)
                return nil
            end
            failures = failures + 1
            log:debug('failed to retrieve response protocol: %s, msg: %s. Retry: %d/%d',
                self.name, rsp_raw, failures, max_retry_count)
        else
            failures = 0
            track_request.data_written = track_request.data_written + track_request.len
            write_done = track_request.data_written >= #track_request.data
            read_done = self:_concat_response(response, rsp_raw)
        end
    end
    return response.data
end

--- Build the per-request tracking state, choose the frame length and dispatch the request.
-- The frame length is chosen as follows:
--   * expect_data_len < 0: variable-length read, use a full frame;
--   * expect_data_len > 0: fixed-length read or write with read-back, capped at the frame size;
--   * no or zero expect_data_len with data: write only, capped at the (padded) data size;
--   * otherwise: a full frame.
-- @tparam table request request parameters
-- @return result of batch_write, write_without_read or write_and_read
function std_smbus:_send_request_internal(request)
    -- An offset of -1 means the frame offset field is always sent as 0.
    local offset_invalid = request.offset == -1
    local start_offset = 0
    if not offset_invalid and request.offset then
        start_offset = request.offset
    end

    local track_request = {
        opcode = tonumber(request.opcode),
        data_written = start_offset,
        offset_invalid = offset_invalid,
        len = 0,
        max_frame_len = request.max_frame_len or self:get_max_frame_len(),
        data = request.data or '',
        arg = request.arg or 0,
        expect_data_len = request.expect_data_len or -1,
        data_object_index = request.data_object_index or 0
    }
    -- Pad the data to the requested alignment before any length is derived from it.
    self:append_data(track_request, request)

    local max_len = track_request.max_frame_len
    local expect = request.expect_data_len
    local frame_len
    if expect == nil or expect == 0 then
        if request.data ~= nil and #track_request.data > 0 then
            frame_len = math.min(max_len, #track_request.data)
        else
            frame_len = max_len
        end
    elseif expect < 0 then
        frame_len = max_len
    else
        frame_len = math.min(max_len, expect)
    end
    track_request.len = frame_len

    if request.batch_write then
        return self:batch_write(track_request)
    end
    if request.write_without_read then
        return self:write_without_read(track_request)
    end
    return self:write_and_read(track_request)
end

--- Public entry point for sending a request.
-- Without the parallel flag, requests are serialised through the instance's skynet
-- queue. With it, they run directly in the caller's coroutine.
-- @tparam table request request parameters
-- @return whatever _send_request_internal returns
function std_smbus:send_request(request)
    local function run()
        return self:_send_request_internal(request)
    end
    if not self.parallel then
        return self.queue(run)
    end
    return run()
end

--- Construct a std_smbus instance.
-- @tparam table params must contain ref_chip and buffer_len. mcu_address is optional
--   (defaults to default_mcu_address); parallel is optional.
-- Raises through log:raise when the required params are missing.
function std_smbus:ctor(params)
    local has_required = params and params.ref_chip and params.buffer_len
    if not has_required then
        log:raise('unable to create smbus with invalid params')
        return
    end
    self.name = 'std_smbus'
    -- Chip handle and transport settings taken from params.
    self.ref_chip = params.ref_chip
    self.buffer_len = params.buffer_len
    self.mcu_address = params.mcu_address or default_mcu_address
    -- Codecs: an 8-bit count prefix followed by the request header or response body.
    self.request_bs = bs.new(frame_template, { body = request_header_bs })
    self.response_bs = bs.new(frame_template, { body = response_data_bs })
    -- Module-level defaults exposed on the instance.
    self.write_read_delay = write_read_delay
    self.request_params_template = request_params_template
    self.error_code = error_code
    self.parallel = params.parallel
    -- Queue used by send_request to serialise requests when parallel is not set.
    self.queue = skynet_queue()
end

return std_smbus
