-- Copyright (c) 2024 Huawei Technologies Co., Ltd.
-- openUBMC is licensed under Mulan PSL v2.
-- You can use this software according to the terms and conditions of the Mulan PSL v2.
--         http://license.coscl.org.cn/MulanPSL2
-- THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-- EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
-- MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
-- See the Mulan PSL v2 for more details.

local bs = require 'mc.bitstring'
local log = require 'mc.logging'
local context = require 'mc.context'
local crc8 = require 'mc.crc8'
local class = require 'mc.class'
local utils = require 'mc.utils'

-- Bytes passed to the chip's Write / ComboWriteRead calls. 0x40 is also the first
-- byte fed into the request PEC; 0x41 is the response-side byte.
local RETIMER_CMD_REQUEST<const> = 0x40
local RETIMER_CMD_RESPONSE<const> = 0x41
-- Indirect access ports for registers at or above 0x10000: the target register
-- address is written to ADDRESS_PORT, then the value is read from DATA_PORT.
local ADDRESS_PORT<const> = 0xFFF0
local DATA_PORT<const> = 0xFFF4
-- Length in bytes of a read response, not counting the trailing PEC byte
-- (identifier spelling kept as-is because the rest of the file uses it).
local RECEIVCE_BYTCNT<const> = 7
-- 16-bit request command codes (514 and 1798 in decimal).
local REQ_SEND_ADDR_CCODE<const> = 0x0202
local REQ_SEND_DATA_CCODE<const> = 0x0706

-- Bitstring layout of a read response (RECEIVCE_BYTCNT bytes, PEC excluded):
-- an 8-bit count field, a 16-bit little-endian register address that
-- parse_response compares with the requested one, and a 32-bit big-endian value.
local RESPONSE_BODY<const> = [[<<
    count:8,
    addr:16/little,
    data:32/big
>>]]

-- Request layouts, keyed by operation:
--   SEND_ADDR      : command code + 16-bit register address (sent before a read)
--   SEND_ADDR_ADTA : command code + 16-bit register address + 32-bit value (a write)
-- `length` is the packed size in bytes of each layout (before any PEC byte).
-- The SEND_ADDR_ADTA key spelling is kept because callers look it up by that name.
local REQUEST_CONTEXT<const> = {
    SEND_ADDR = {
        cmd_code = REQ_SEND_ADDR_CCODE,
        length = 4,
        pattern = bs.new([[<<
            cmd_code:16/big,
            addr:16/little
        >>]]),
    },
    SEND_ADDR_ADTA = {
        cmd_code = REQ_SEND_DATA_CCODE,
        length = 8,
        pattern = bs.new([[<<
            cmd_code:16/big,
            addr:16/little,
            data:32/big
        >>]]),
    },
}

-- 按bit反转一个uint_32整数的高效算法
--- Reverse the bit order of a 32-bit unsigned integer.
-- Swaps progressively larger groups (1, 2, 4, 8, then 16 bits) between their
-- halves; the masks keep every intermediate within 32 bits, so only the low
-- 32 bits of the input take part and the result is always in [0, 0xffffffff].
-- @param x integer
-- @return integer with bit i of x moved to bit 31 - i
local function reverse_bits(x)
    local v = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1) -- swap adjacent bits
    v = ((v >> 2) & 0x33333333) | ((v & 0x33333333) << 2)       -- swap bit pairs
    v = ((v >> 4) & 0x0f0f0f0f) | ((v & 0x0f0f0f0f) << 4)       -- swap nibbles
    v = ((v >> 8) & 0x00ff00ff) | ((v & 0x00ff00ff) << 8)       -- swap bytes
    return (v >> 16) | ((v & 0xffff) << 16)                     -- swap half-words
end

-- 封装crc8校验，与read_data比较
--- Verify a PEC byte: compute crc8 over `data` and compare it with `check_sum`.
-- @param check_sum PEC value taken from the received bytes
-- @param data bytes the PEC is expected to cover
-- @return true when they match; otherwise logs both values at debug level and returns false
local function check_data(check_sum, data)
    local computed = crc8(data)
    local matched = (computed == check_sum)
    if not matched then
        log:debug('get crc: %s, real crc: %s', check_sum, computed)
    end
    return matched
end

-- Driver object for the M88RT51632 retimer, accessed over SMBus.
local smbus_51632 = class()

--- Remember the chip object that performs the raw bus transfers.
-- @param chip chip object; the methods below call its Write and ComboWriteRead through pcall
function smbus_51632:ctor(chip)
    self.chip = chip
end

--- Build the request bytes for one of the REQUEST_CONTEXT layouts.
-- @param addr register address stored in the 16-bit little-endian addr field
-- @param data optional value; written only when it is given and the layout's
--        zero template has a data field, after its bits are reversed with reverse_bits
-- @param pattern a REQUEST_CONTEXT entry (compiled `pattern`, `cmd_code`, `length`)
-- @param use_pec when truthy, bit 15 of the command code is set and a crc8 PEC byte,
--        computed over RETIMER_CMD_REQUEST followed by the payload, is appended
-- @return packed request string (the debug log shows it before any PEC byte is added)
function smbus_51632:construct_request(addr, data, pattern, use_pec)
    local layout = pattern.pattern
    -- Decoding an all-zero buffer yields a field table shaped like this layout.
    local fields = layout:unpack(string.rep('\x00', pattern.length), true)

    local cmd_code = pattern.cmd_code
    if use_pec then
        cmd_code = cmd_code | 0x8000
    end
    fields.cmd_code = cmd_code
    fields.addr = addr
    if data and fields.data then
        fields.data = reverse_bits(data)
    end

    local payload = layout:pack(fields)
    if log:getLevel() >= log.DEBUG then
        log:debug("smbus construct_request, %s", utils.to_hex(payload))
    end
    if not use_pec then
        return payload
    end
    local pec = crc8(string.char(RETIMER_CMD_REQUEST) .. payload)
    return payload .. string.char(pec)
end

--- Send a write request (command + 16-bit address + 32-bit value) to the retimer.
-- @param addr register address for the request's 16-bit addr field
-- @param data value to write (construct_request reverses its bits before packing)
-- @param use_pec when truthy, request a PEC-protected transfer
-- @return ok, msg as produced by pcall(self.chip.Write, ...)
function smbus_51632:smbus_write(addr, data, use_pec)
    local send_data = self:construct_request(addr, data, REQUEST_CONTEXT['SEND_ADDR_ADTA'], use_pec)
    local ok, msg = pcall(self.chip.Write, self.chip, context.get_context_or_default(), RETIMER_CMD_REQUEST, send_data)
    if not ok then
        -- Callers (retimer_write, retimer_read) keep only `ok` and replace the error
        -- with a generic string, so record the real cause here, as smbus_read_quick does.
        log:error("smbus write %s fail, error: %s", utils.to_hex(send_data), msg)
    end
    return ok, msg
end

--- Read the raw response for a register address below 0x10000.
-- First writes a SEND_ADDR request for `addr`, then issues a ComboWriteRead that
-- sends the read command (0x81 with PEC, 0x01 without) and reads the response.
-- @param addr 16-bit register address
-- @param use_pec when truthy, the request carries a PEC byte and the response's
--        trailing PEC byte is verified and stripped
-- @return true, response bytes (RECEIVCE_BYTCNT or more, PEC removed) on success;
--         false, error from Write if the address write fails; false, nil on any other failure
function smbus_51632:smbus_read_quick(addr, use_pec)
    local send_data = self:construct_request(addr, nil, REQUEST_CONTEXT['SEND_ADDR'], use_pec)
    local ok, msg = pcall(self.chip.Write, self.chip, context.get_context_or_default(), RETIMER_CMD_REQUEST, send_data)
    if not ok then
        log:error("smbus begin write %s fail, error: %s", utils.to_hex(send_data), msg)
        return ok, msg
    end

    local read_cmd = use_pec and '\x81' or '\x01'
    local read_len = RECEIVCE_BYTCNT + (use_pec and 1 or 0)
    ok, msg = pcall(self.chip.ComboWriteRead, self.chip, context.get_context_or_default(),
        RETIMER_CMD_REQUEST, read_cmd, RETIMER_CMD_RESPONSE, read_len)
    if not ok then
        log:error('smbus_read_quick ComboWriteRead fail, error: %s', msg)
        return false, nil
    end

    -- A missing or truncated response would otherwise crash on msg:sub (PEC path)
    -- or be handed to parse_response without enough bytes to decode.
    if type(msg) ~= 'string' or #msg < read_len then
        log:error('smbus_read_quick get invalid response, expect %s bytes, actual %s',
            read_len, type(msg) == 'string' and #msg or type(msg))
        return false, nil
    end

    if use_pec then
        -- PEC input: request byte, read command, response byte, then the data bytes.
        local body = msg:sub(1, -2)
        local check_buf = string.char(RETIMER_CMD_REQUEST) .. read_cmd ..
            string.char(RETIMER_CMD_RESPONSE) .. body
        if not check_data(msg:byte(-1), check_buf) then
            log:error('PEC check fail %s', utils.to_hex(check_buf))
            return false, nil
        end
        msg = body
    end
    return true, msg
end

--- Write a 32-bit value to a retimer register.
-- @param args table with fields `addr`, `data` and `use_pec`
-- @return true, nil on success; false, 'fail to write retimer' on failure
-- NOTE(review): unlike retimer_read, addresses >= 0x10000 are not routed through
-- ADDRESS_PORT/DATA_PORT; `addr` goes straight into the 16-bit addr field of the
-- request. Confirm whether indirect writes are needed for such registers.
function smbus_51632:retimer_write(args)
    if self:smbus_write(args.addr, args.data, args.use_pec) then
        return true, nil
    end
    return false, 'fail to write retimer'
end

--- Read a 32-bit retimer register.
-- Per the M88RT51632 SDK manual, registers below 0x10000 can be read directly over
-- SMBus. For registers at or above 0x10000, the target address must first be
-- written to ADDRESS_PORT and the value then read from DATA_PORT.
-- @param args table with fields `addr` and `use_pec`
-- @return true, register value on success; false, error string on failure
function smbus_51632:retimer_read(args)
    -- Address actually placed in the read request; the response is checked against it.
    local read_addr = args.addr
    if args.addr >= 0x10000 then
        if not self:smbus_write(ADDRESS_PORT, args.addr, args.use_pec) then
            return false, 'smbus_write fail'
        end
        -- The read request now targets DATA_PORT. The response addr field is only
        -- 16 bits wide, so comparing it with args.addr (>= 0x10000) could never match.
        -- Assumes the response echoes the requested address, as the direct path does.
        read_addr = DATA_PORT
    end

    local ok, msg = self:smbus_read_quick(read_addr, args.use_pec)
    if not ok then
        return false, 'smbus_read fail'
    end
    return self:parse_response(msg, read_addr)
end

--- Poll a register until `fun` accepts its value, then read it once more.
-- Up to 5 reads are attempted back to back, with no delay between them. As soon as
-- a read succeeds and fun(value) returns truthy, a fresh retimer_read(args) is
-- issued and its result is returned. The value that satisfied `fun` is not returned.
-- @param args same table as retimer_read (`addr`, `use_pec`)
-- @param fun predicate called with each successfully read value
-- @return result of the final retimer_read, or false plus an error string when no
--         attempt satisfied `fun`
function smbus_51632:retimer_loop_read(args, fun)
    local max_attempts = 5
    for _ = 1, max_attempts do
        local ok, value = self:retimer_read(args)
        if ok and fun(value) then
            return self:retimer_read(args)
        end
    end
    return false, "fail to get loop stop flag!"
end

-- The response layout is fixed, so compile it once at load time instead of on every
-- call (the request layouts in REQUEST_CONTEXT are compiled the same way).
local RESPONSE_PATTERN<const> = bs.new(RESPONSE_BODY)

--- Decode a read response and return the register value.
-- @param rawdata response bytes with any PEC byte already stripped
-- @param addr address the read request was sent for; the response addr must match it
-- @return true, value (bit order restored with reverse_bits) on success;
--         false, error string when the data is too short or undecodable, or the
--         address does not match
function smbus_51632:parse_response(rawdata, addr)
    -- Short or missing data would otherwise fail inside unpack or on indexing a nil result.
    if type(rawdata) ~= 'string' or #rawdata < RECEIVCE_BYTCNT then
        log:error('get invalid response length, expect at least %s bytes', RECEIVCE_BYTCNT)
        return false, 'invalid rsp length'
    end
    local result = RESPONSE_PATTERN:unpack(rawdata, true)
    if not result then
        log:error('fail to unpack response %s', utils.to_hex(rawdata))
        return false, 'invalid rsp'
    end
    if result.addr ~= addr then
        log:error("get wrong response, expect addr:%s, actual addr:%s", addr, result.addr)
        return false, "wrong rsp addr"
    end
    return true, reverse_bits(result.data)
end

-- Module export: the retimer SMBus driver class.
return smbus_51632
