#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
import ast
import argparse
import json
import os
import sys
import importlib
# Directory of shared plugin utilities (presumably where the whitelisted
# modules below live — verify); appended to the import path so they can be
# imported by bare module name.
UTIL_PATH = "/usr/share/bmc_studio/server/tools/plugins/util"
sys.path.append(UTIL_PATH)

# Module names that rule scripts are allowed to load through
# ``importlib.import_module`` (checked by DangerousNodeVisitor.visit_Call).
WHITELISTED_CLASSES = ["issue", "grammar_rule"]


class DangerousNodeVisitor(ast.NodeVisitor):
    """AST visitor that rejects rule source code containing risky constructs.

    ``visit`` raises ``ValueError`` on the first forbidden construct it meets:

    * importing a blocked library or any of its submodules
      (``import os.path``, ``from os.path import join``);
    * calling a function or method whose name is in ``dangerous_function``;
    * calling ``import_module`` (as ``importlib.import_module`` or a bare
      ``import_module``) with anything other than a whitelisted string literal.

    This is a best-effort static screen, not a sandbox: indirect lookups such
    as ``getattr`` tricks or aliased names are not detected.
    """

    def __init__(self, whitelist=None):
        """Set up the block/allow lists.

        Args:
            whitelist: module names that may be passed to ``import_module``.
                Defaults to the module-level ``WHITELISTED_CLASSES``.
        """
        # Kept for compatibility; no check in this class consults it.
        self.allowed_sys_function = 'sys.path.append'
        # 'retree' is kept as-is; 'rmtree' (shutil) is presumably what it meant.
        self.dangerous_function = [
            'system', 'popen', 'eval', 'exec', 'spawn', 'retree', 'rmtree', 'input', 'open',
            'compile', 'execfile', '__import__', 'id'
        ]
        self.danger_library = ['sys', 'os', 'subprocess', 'shutil', 'multiprocessing']
        self.whitelisted_modules = WHITELISTED_CLASSES if whitelist is None else list(whitelist)

    @staticmethod
    def _top_level(module_name):
        """Return the top-level package of a dotted name ('os.path' -> 'os'); '' for None."""
        return (module_name or "").split(".")[0]

    def visit_Import(self, node):
        """Reject ``import X`` / ``import X.sub`` when X is a blocked library."""
        for alias in node.names:
            if self._top_level(alias.name) in self.danger_library:
                raise ValueError(f"危险操作：禁止导入 {alias.name}")
        self.generic_visit(node)

    def visit_ImportFrom(self, node):
        """Reject ``from X[.sub] import ...`` when X is a blocked library."""
        # node.module is None for ``from . import x``; _top_level maps it to ''.
        if self._top_level(node.module) in self.danger_library:
            raise ValueError(f"禁止从 {node.module} 模块中导入内容")
        self.generic_visit(node)

    def visit_Call(self, node):
        """Reject calls to blocked function names and unverifiable dynamic imports."""
        func = node.func
        if isinstance(func, ast.Name):
            func_name = func.id
        elif isinstance(func, ast.Attribute):
            func_name = func.attr
        else:
            func_name = None

        if func_name in self.dangerous_function:
            raise ValueError(f"危险操作：禁止调用 {func_name}")

        # Covers importlib.import_module(...), aliased modules, and a bare
        # import_module(...) obtained via ``from importlib import import_module``.
        if func_name == 'import_module':
            self._check_dynamic_import(node)

        self.generic_visit(node)

    def _check_dynamic_import(self, node):
        """Allow ``import_module`` only with a whitelisted string-literal module name.

        Raises:
            ValueError: if the module name is not a literal string (so it cannot
                be verified statically) or is not in the whitelist.
        """
        # The module name may be passed positionally or as ``name=``.
        target = node.args[0] if node.args else None
        if target is None:
            for keyword in node.keywords:
                if keyword.arg == 'name':
                    target = keyword.value
                    break
        if isinstance(target, ast.Constant) and isinstance(target.value, str):
            module_name = target.value
            if module_name not in self.whitelisted_modules:
                raise ValueError(f"危险操作：禁止动态导入 {module_name}")
            return
        # Variables, expressions, *args/**kwargs or a missing name: cannot be vetted.
        raise ValueError("危险操作：禁止使用非字面量参数动态导入模块")


class ModelCheckExtend:
    """Load extension grammar rules from disk and run them against model files.

    A rule is a ``<path>/src/<rule_id>.py`` module that is listed both in
    ``<path>/rule_definition/grammar_check.json`` and in the *enabled* list,
    whose ``category`` maps (via ``category_model``) to the current model, and
    whose source passes ``DangerousNodeVisitor``. The class defined in that
    module is instantiated with the rule's JSON definition and is expected to
    provide ``validate(file_data, model, file_path, file_type=..., require_data=...)``.
    """

    def __init__(self, model: str, path: str, enabled: str):
        """Load the rules applicable to *model*.

        Args:
            model: model name (e.g. "csr", "mds") used to select rules by category.
            path: root directory of the extension rule scripts.
            enabled: string form of a Python list of enabled rule ids.
        """
        # Rule category (as written in grammar_check.json) -> models it applies to.
        self.category_model = {
            "CSR": ["csr"],
            "自定义类型": ["mds", "resource_tree"],
            "部件接口": ["device_tree"],
            "接口映射": ["interface_mapping"],
            "IPMI": ["mds"],
            "MDS": ["mds"],
            "资源协作接口": ["resource_tree"]
        }
        self.model = model
        self.rules = self.load_rules(path, enabled)

    @staticmethod
    def check_args(args):
        """Validate the CLI namespace.

        Raises:
            TypeError: if ``input``, ``data``, ``rootpath`` or ``output`` is
                missing or does not exist on disk.
        """
        input_file = args.input
        if input_file is None or not os.path.exists(input_file):
            raise TypeError(f"缺少参数或输入文件{input_file}不存在")
        data_file = args.data
        if data_file is None or not os.path.exists(data_file):
            raise TypeError(f"缺少参数或依赖数据文件{data_file}不存在")
        file_rootpath = args.rootpath
        if file_rootpath is None or not os.path.exists(file_rootpath):
            raise TypeError(f"缺少参数或根路径{file_rootpath}不存在")
        output_file = args.output
        if output_file is None or not os.path.exists(output_file):
            raise TypeError("缺少参数或结果输出文件不存在")
        return

    @staticmethod
    def load_data(data_file: str):
        """Parse and return the JSON content of *data_file* (UTF-8)."""
        with open(data_file, 'r', encoding='utf-8') as file:
            data = json.load(file)

        return data

    def load_rules(self, path: str, enabled: str):
        """Discover, vet and instantiate the enabled rules applicable to ``self.model``.

        Args:
            path: root directory containing ``rule_definition/grammar_check.json``
                and a ``src`` folder with one ``<rule_id>.py`` per rule.
            enabled: string form of a list of enabled rule ids; input that
                ``ast.literal_eval`` cannot parse is treated as an empty list.

        Returns:
            dict mapping rule id to the instantiated rule object. Rules that fail
            to parse, vet, import or instantiate are logged and skipped.

        Raises:
            ValueError: if grammar_check.json is not a JSON object.
        """
        with open(os.path.join(path, 'rule_definition/grammar_check.json'), 'r', encoding='utf-8') as file:
            data = json.load(file)
        if not isinstance(data, dict):
            raise ValueError("grammar_check.json格式错误，期望为字典形式")
        try:
            enabled_list = ast.literal_eval(enabled)
        except (ValueError, SyntaxError) as e:
            enabled_list = []
            print(f"Error converting string to list: {e}")

        rules = {}
        # Derive the folder from the ``path`` argument, not a script-level global,
        # so the class also works when imported.
        src_path = os.path.join(path, 'src')
        # Rule modules are imported by bare name, so their folder must be importable.
        if src_path not in sys.path:
            sys.path.append(src_path)
        for file_name in sorted(os.listdir(src_path)):
            if not file_name.endswith(".py") or file_name == "__init__.py":
                continue
            module_name = file_name[:-3]  # strip the .py suffix
            if module_name not in data or module_name not in enabled_list:
                continue
            definition = data[module_name]
            if not (isinstance(definition, dict) and
                    self.model in self.category_model.get(definition.get("category", ""), [])):
                continue
            # Vet the source before the module is ever imported (i.e. executed).
            if not self._is_source_safe(src_path, file_name, module_name):
                continue
            rule = self._instantiate_rule(module_name, definition)
            if rule is None:
                continue
            rules[module_name] = rule
            print(f"load rule {module_name} success.", flush=True)

        return rules

    @staticmethod
    def _is_source_safe(src_path, file_name, module_name):
        """Parse a rule file and screen it with DangerousNodeVisitor; log and return False on failure."""
        try:
            with open(os.path.join(src_path, file_name), 'r', encoding='utf-8') as f:
                code = f.read()
            tree = ast.parse(code)
        except SyntaxError as e:
            print(f"{module_name}.py has syntax error: line {e.lineno}, offset {e.offset}, error {e.msg}",
                  flush=True)
            return False
        except (OSError, ValueError) as e:
            # Unreadable file, undecodable bytes, or null bytes in the source.
            print(f"{module_name}.py could not be read: {e}", flush=True)
            return False
        try:
            DangerousNodeVisitor().visit(tree)
        except Exception as e:
            print(f"The source code of rule {module_name} contains risky operations: {e}", flush=True)
            return False
        return True

    @staticmethod
    def _instantiate_rule(module_name, definition):
        """Import a vetted rule module and instantiate the rule class it defines.

        Only classes defined in the module itself are considered, so imported
        helpers (e.g. ``OrderedDict``) are never mistaken for the rule. When
        several qualify, the last one in ``dir()`` order is used. Returns None
        (after logging) when the module cannot be imported, defines no class,
        or the class constructor raises.
        """
        try:
            module = importlib.import_module(module_name)
        except Exception as e:
            print(f"load rule {module_name} failed, error: {e}", flush=True)
            return None

        rule_class = None
        for attr_name in dir(module):
            if attr_name in ("GrammarIssue", "GrammarRule"):
                continue
            attr = getattr(module, attr_name)
            if isinstance(attr, type) and attr.__module__ == module.__name__:
                rule_class = attr
        if rule_class is None:
            print(f"load rule {module_name} failed, error: no rule class defined in module", flush=True)
            return None

        try:
            return rule_class(definition)
        except Exception as e:
            print(f"load rule {module_name} failed, error: {e}", flush=True)
            return None

    def run(self, args):
        """Run every loaded rule over every input file and write issues to ``args.output``.

        Args:
            args: parsed CLI namespace with ``input``, ``data``, ``rootpath`` and ``output``.

        After each rule that reports issues, the output file is rewritten or
        extended so it holds one JSON array of all issues found so far. A rule
        that raises is logged; issues it produced before failing are still written.
        """
        self.check_args(args=args)
        datamap = self.load_data(data_file=args.data)
        files_data = self.load_files(input_file=args.input, file_rootpath=args.rootpath)
        first_write = True
        for rule_id, rule in self.rules.items():
            results = []
            print(f"{rule_id} start check.", flush=True)
            try:
                for file_path, item in files_data.items():
                    issues = rule.validate(item['file_data'], self.model, file_path,
                                           file_type=item["file_type"], require_data=datamap)
                    results.extend(issues)
                print(f"{rule_id} check success.", flush=True)
            except Exception as e:
                print(f"{rule_id} check failed, error: {e}", flush=True)
            finally:
                sys.stdout.flush()
            if not results:
                continue
            self._write_results(args.output, results, first_write)
            first_write = False

        return

    @staticmethod
    def _write_results(output_file, results, first_write):
        """Start a new JSON array in *output_file*, or splice *results* into the existing one.

        The file is handled in binary mode because text-mode files only allow
        seeking to opaque ``tell()`` cookies; ``json.dumps`` escapes non-ASCII
        by default, so the encoded bytes are plain ASCII.
        """
        # json.dumps(..., indent=4) yields "[\n    ...\n]"; keep only the inner part.
        data_str = json.dumps(results, indent=4)[1:-1]
        with open(output_file, 'rb+') as file:
            if first_write:
                file.write(('[' + data_str + ']').encode('utf-8'))
                # Drop stale bytes left over from any longer previous content.
                file.truncate()
            else:
                # Overwrite the trailing "\n]" of the array written so far.
                file.seek(-2, os.SEEK_END)
                file.write((',' + data_str + ']').encode('utf-8'))

    def load_files(self, input_file: str, file_rootpath: str):
        """Load the JSON content of every file listed in *input_file*.

        Args:
            input_file: JSON object whose values carry ``filepath`` (relative to
                *file_rootpath*) and ``filetype``.
            file_rootpath: directory the listed paths are resolved against.

        Returns:
            dict keyed by file path with ``file_path``, ``file_type`` and
            ``file_data`` entries. Files that fail to load are logged and skipped.
        """
        items = self.load_data(input_file)
        files_data = {}
        for item in items.values():
            file_path = item.get("filepath", "")
            try:
                with open(os.path.join(file_rootpath, file_path), 'r', encoding='utf-8') as file:
                    file_data = json.load(file)
            except Exception as e:
                # Best effort: one unreadable file must not stop the whole check.
                print(f"文件{file_path}内容读取失败，报错：{e}", flush=True)
                continue
            files_data[file_path] = {
                "file_path": file_path,
                "file_type": item.get("filetype", ""),
                "file_data": file_data
            }

        return files_data


if __name__ == "__main__":
    # CLI entry point: parse arguments, make the rule sources importable, run all rules.
    parser = argparse.ArgumentParser(description="扩展规则调用入口，用于模型检查")
    parser.add_argument("-i", "--input", help="输入文件，存放待检查文件列表")
    parser.add_argument("-r", "--rootpath", help="待检查文件的根路径")
    parser.add_argument("-d", "--data", help="检查依赖数据文件")
    parser.add_argument("-o", "--output", help="输出文件")
    parser.add_argument("-m", "--model", help="模型名称")
    parser.add_argument("-p", "--path", help="扩展规则实现脚本路径")
    parser.add_argument("-e", "--enabled", help="启用的扩展规则列表，包括选中的符合规范的规则")
    # Unknown arguments are tolerated and ignored.
    args, _ = parser.parse_known_args()
    # --path is used before check_args runs (rules load in the constructor), so a
    # missing value would otherwise surface as a TypeError traceback from os.path.join.
    if args.path is None:
        parser.error("缺少参数 -p/--path")
    sys.path.append(os.path.join(args.path, 'src'))
    ModelCheckExtend(args.model, args.path, args.enabled).run(args)
