# coding: utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd.
# openUBMC is licensed under Mulan PSL v2.
# You can use this software according to the terms and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#         http://license.coscl.org.cn/MulanPSL2
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.

import unittest
import tempfile
import json
import shutil
import tarfile
from pathlib import Path
from functools import wraps
from collections import OrderedDict
from bmcgo.utils.tools import Tools

# Shared Tools helper: its run_command() drives the bingo builds below and its
# logger (`log`) reports test progress and failures.
tools = Tools()
log = tools.log


def split_json(json_string, output_prefix):
    """Split a CSR JSON document into three partial ``.sr`` files.

    Writes:
      * ``<output_prefix>_basic_info.sr`` - ``Unit``, ``ManagementTopology`` and
        the first half (rounded up) of ``Objects``;
      * ``<output_prefix>_mgmt_model.sr`` - the remaining ``Objects`` entries;
      * ``<output_prefix>_version.sr``    - ``FormatVersion`` and ``DataVersion``.

    Keys missing from the input are simply omitted, and ``Objects`` is only
    split when it is a JSON object. Key order of the input is preserved.
    """
    source = json.loads(json_string, object_pairs_hook=OrderedDict)

    def pick(*keys):
        # Copy the requested top-level keys that are present, keeping this order.
        return OrderedDict((key, source[key]) for key in keys if key in source)

    basic = pick('Unit', 'ManagementTopology')
    model = OrderedDict()

    objects = source.get('Objects')
    if isinstance(objects, dict):
        entries = list(objects.items())
        split_at = (len(entries) + 1) // 2
        if split_at:
            basic['Objects'] = OrderedDict(entries[:split_at])
        if split_at < len(entries):
            model['Objects'] = OrderedDict(entries[split_at:])

    version = pick('FormatVersion', 'DataVersion')

    outputs = (
        (f"{output_prefix}_basic_info.sr", basic),
        (f"{output_prefix}_mgmt_model.sr", model),
        (f"{output_prefix}_version.sr", version),
    )
    for target, payload in outputs:
        with open(target, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, indent=2, ensure_ascii=False)

# Reference CSR document shared by all test cases. split_json() splits it into
# *_basic_info.sr / *_mgmt_model.sr / *_version.sr, and the package built from
# those three files is expected to contain this document again.
ORIGIN_CSR = '''{
    "FormatVersion": "3.00",
    "DataVersion": "3.00",
    "ManagementTopology": {
        "Anchor": {
            "Buses": [
                "I2c_1"
            ]
        },
        "I2c_1": {
            "Chips": [
                "Chip_MCU",
                "Chip_PcbId",
                "Eeprom_SDI"
            ]
        }
    },
    "Objects": {
        "Eeprom_SDI": {
            "OffsetWidth": 2,
            "AddrWidth": 1,
            "Address": 160,
            "WriteTmout": 100,
            "ReadTmout": 100,
            "RwBlockSize": 32,
            "WriteInterval": 20,
            "HealthStatus": 0
        },
        "SRUpgrade_1": {
            "UID": "00000001A40302050670",
            "Type": "SDI",
            "Version": "${DataVersion}",
            "StorageChip": "#/Eeprom_SDI",
            "SoftwareId": "HWSR-IT21SHSM",
            "WriteProtectChip": "#/Chip_MCU"
        }
    }
}'''

# Parsed form of ORIGIN_CSR, used as the expected build result. It is one object
# shared by every test, so treat it as read-only: a test that needs a variant
# should re-parse ORIGIN_CSR instead of modifying this dict in place.
ORIGIN_CSR_DICT = json.loads(ORIGIN_CSR)


class TestThreeCsrBuild(unittest.TestCase):
    def __init__(self, methodName="runTest"):
        super().__init__(methodName)
        self.tmp_dir = None
        self.output_prefix = None
        self.testcase = 0
        self.failed = []

    @staticmethod
    def with_tempdir(func):
        """实例方法装饰器：创建临时文件夹并赋值给self.tmp_dir"""
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            self.testcase += 1
            log.info(f"Testcase {self.testcase} start")
            log.info(func.__doc__)
            self.tmp_dir = Path(tempfile.mkdtemp())
            self.output_prefix = self.tmp_dir / 'test'
            is_pass = True
            try:
                func(self, *args, **kwargs)
            except Exception as e:
                self.failed.append(self.testcase)
                log.error(e)
                is_pass = False
            finally:
                if is_pass:
                    log.info(f"Testcase {self.testcase} passed")
                if self.tmp_dir:
                    shutil.rmtree(self.tmp_dir, ignore_errors=True)
                self.tmp_dir = None
                self.output_prefix = None
        return wrapper
    
    def as_same(self, tar_gz_path, expected_json):
        with tarfile.open(tar_gz_path, 'r:gz') as tar:
            sr_file = None
            for member in tar.getmembers():
                if member.isfile() and member.name.endswith('.sr'):
                    sr_file = member
                    break
            extracted_file = tar.extractfile(sr_file)
            content = extracted_file.read().decode('utf-8')
            actual_json = json.loads(content)
            assert actual_json == expected_json

    @with_tempdir
    def test_single_with_3_csr_without_plain(self):
        """测试含有拆分后的csr但不包含原csr的场景"""
        split_json(ORIGIN_CSR, self.output_prefix)
        cmd = f"bingo build -s -p {self.output_prefix.with_suffix('.sr')} -o {self.tmp_dir} --json"
        tools.run_command(cmd, sudo=True, command_echo=False, ignore_error=True, capture_output=False)
        tar_gz_path = list(self.tmp_dir.glob("*.gz"))[0]
        self.as_same(tar_gz_path, ORIGIN_CSR_DICT)

    @with_tempdir
    def test_single_with_3_csr_with_plain(self):
        """测试同时含有拆分后的csr和与原csr同名csr的场景"""
        split_json(ORIGIN_CSR, self.output_prefix)
        new_dict = ORIGIN_CSR_DICT
        new_dict["DataVersion"] = "2.00"
        with open(self.tmp_dir / "test.sr", "w") as f:
            json.dump(new_dict, f, indent=2, ensure_ascii=False)
        cmd = f"bingo build -s -p {self.output_prefix.with_suffix('.sr')} -o {self.tmp_dir} --json"
        tools.run_command(cmd, sudo=True, command_echo=False, ignore_error=True, capture_output=False)
        tar_gz_path = list(self.tmp_dir.glob("*.gz"))[0]
        self.as_same(tar_gz_path, new_dict)

    @with_tempdir
    def test_lack_csr(self):
        """测试缺少某个拆分csr的场景"""
        split_json(ORIGIN_CSR, self.output_prefix)
        tools.run_command(f"rm -rf {self.output_prefix}_basic_info.sr", sudo=True, 
                          command_echo=False, ignore_error=True, capture_output=False)
        
        cmd = f"bingo build -s -p {self.output_prefix.with_suffix('.sr')} -o {self.tmp_dir} --json"
        tools.run_command(cmd, sudo=True, command_echo=False, ignore_error=True, capture_output=True)
        tar_gz_path = list(self.tmp_dir.glob("*.gz"))
        assert len(tar_gz_path) == 0


if __name__ == '__main__':
    import sys

    # Manual runner: each case is run even if an earlier one fails (failures are
    # collected in tester.failed), then a summary is logged.
    tester = TestThreeCsrBuild()
    tester.test_single_with_3_csr_without_plain()
    tester.test_single_with_3_csr_with_plain()
    tester.test_lack_csr()
    if not tester.failed:
        log.info("Testcases all passed!")
    else:
        log.error("The following testcases failed:")
        log.error(tester.failed)
        # Non-zero exit status so CI / shell callers can detect the failure.
        sys.exit(1)