# Copyright (C) 2023 Intel Corporation
# SPDX-License-Identifier: MIT


import ast
import logging
import sys
from typing import Dict


def generate_numpy_vectorized_code(exp_in: str, symbol_table: Dict = None, metric_name: str = None) -> str:
    """
    Translate a metric expression into a string of numpy vectorized code.

    The stripped expression is parsed in ``eval`` mode, converted node by node, and the
    result is wrapped as a column of constants when it contains neither ``[`` nor ``(``.
    Every exception raised on the way is forwarded, together with the symbol table and the
    metric name, to the module's exception handler.

    :param exp_in: expression for which to generate vectorized code
    :param symbol_table: dictionary that maps symbol names to their values (empty when omitted)
    :param metric_name: metric name for this operation, used when reporting errors
    :return: generated Python code as string; None when the exception handler returns
             instead of raising
    """
    symbol_table = symbol_table or {}
    metric_name = metric_name or ''

    try:
        parsed_expression = ast.parse(exp_in.strip(), mode='eval')
        vectorized = _ast_gen(parsed_expression.body, symbol_table)
        return __determine_if_not_expression(vectorized)
    except Exception as error:
        __handle_exceptions(error, symbol_table, metric_name)
    return None


def __determine_if_not_expression(generated_code):
    """
    Return ``generated_code`` untouched when it contains ``[`` or ``(``; otherwise wrap it
    as a column of constants.

    :param generated_code: code string produced for the whole expression
    :return: the code string to hand back to the caller
    """
    has_expression_marker = any(marker in generated_code for marker in ('[', '('))
    if has_expression_marker:
        return generated_code
    return __convert_to_column_of_constants(generated_code)


def __convert_to_column_of_constants(generated_constant):
    """
    Build the code string for a single-column DataFrame filled with ``generated_constant``.

    The emitted code multiplies the constant by ``np.ones`` shaped ``(len(df), 1)`` and
    reuses ``df.index`` as the index.

    :param generated_constant: text of the constant to broadcast
    :return: the generated code as string
    """
    return f"pd.DataFrame({generated_constant}*np.ones(shape=(len(df),1)), index=df.index)"

def __handle_exceptions(e: Exception, symbol_table: Dict, metric_name: str):
    """
    Re-raise an exception caught during code generation with the metric name appended.

    A ``TypeError`` whose message mentions ``NoneType`` is delegated to
    ``_handle_none_type_type_error``, which may log a warning and return instead of raising.
    Every other exception is re-raised with the message ``'<original message> in <metric_name>.'``,
    keeping the ``SyntaxError`` / ``KeyError`` type where applicable and falling back to
    ``Exception`` otherwise. The original exception is chained as the cause.

    :param e: the exception that was caught
    :param symbol_table: dictionary that maps symbol names to their values
    :param metric_name: metric name to include in the error message
    :return: None, when the delegated NoneType handler returns without raising
    :raises SyntaxError: if ``e`` is a SyntaxError
    :raises KeyError: if ``e`` is a KeyError
    :raises TypeError: from the NoneType handler
    :raises Exception: for any other exception type
    """
    default_exception_msg = f'{e} in {metric_name}.'
    if isinstance(e, TypeError) and 'NoneType' in str(e):
        _handle_none_type_type_error(e, symbol_table, metric_name)
    # isinstance() is required here: testing the bare class object is always truthy and
    # would turn every error into a SyntaxError.
    elif isinstance(e, SyntaxError):
        raise SyntaxError(default_exception_msg) from e
    elif isinstance(e, KeyError):
        raise KeyError(default_exception_msg) from e
    else:
        raise Exception(default_exception_msg) from e


def _handle_none_type_type_error(e: Exception, symbol_table: Dict, metric_name: str):
    """
    Handle a ``TypeError`` caused by a ``None`` value.

    When the symbol table indicates an asymmetric socket configuration, a warning is logged
    and None is returned; otherwise a ``TypeError`` carrying the original message and the
    metric name is raised.

    :param e: the original exception
    :param symbol_table: dictionary that maps symbol names to their values
    :param metric_name: metric name for the log / error message
    :return: None when the error is only logged
    :raises TypeError: when the configuration is not detected as asymmetric
    """
    if not __has_asymmetric_socket_configuration(symbol_table):
        raise TypeError(f'{e} in {metric_name}.')

    logging.warning(f'{metric_name} not computed - asymmetric socket configurations are not supported.')
    return None


def __has_asymmetric_socket_configuration(symbol_table: Dict):
    """
    Tell whether the symbol table holds a falsy core count or thread count for socket 0.

    :param symbol_table: dictionary that maps symbol names to their values
    :return: True if either socket-0 key is present with a falsy value, False otherwise
    """
    socket_count_keys = (
        'system.sockets[0].cores.count',
        'system.sockets[0].cpus.count',
    )
    return any(key in symbol_table and not symbol_table[key] for key in socket_count_keys)


def is_number(value) -> bool:
    """
    Check whether ``value`` can be converted to a float.

    Commas are removed from plain ``str`` values before the conversion is attempted.
    Only ``ValueError`` is treated as "not a number"; any other error raised by ``float()``
    (for instance a ``TypeError``) propagates to the caller.

    :param value: value to test
    :return: True if ``float()`` accepts the value, False if it raises ``ValueError``
    """
    candidate = value.replace(',', '') if type(value) is str else value
    try:
        float(candidate)
    except ValueError:
        return False
    return True


class _NameVisitor(ast.NodeVisitor):
    """
    AST visitor that gathers the generated code of every ``ast.Name`` it encounters.

    It also keeps a textual flag describing the most recently visited leaf: ``'False'`` after
    a name, ``'True'`` after a constant, and ``''`` while neither has been visited.
    """

    def __init__(self, symbol_table):
        self._symbol_table = symbol_table
        self._generated_names = set()
        self._constants_flag = ''

    def visit_Name(self, node):
        generated = _ast_gen(node, self._symbol_table)
        self._generated_names.add(generated)
        self._constants_flag = 'False'

    def visit_Constant(self, node):
        self._constants_flag = 'True'

    @property
    def names(self):
        """Generated code of the visited names, de-duplicated and sorted."""
        return sorted(self._generated_names)

    @property
    def constants_only(self):
        """The ``'True'`` / ``'False'`` / ``''`` flag described on the class."""
        return self._constants_flag


def _generate_all_finite_test(node, symbol_table):
    """
    Generate the code that tests whether every name used under ``node`` is finite.

    :param node: AST node to scan for names
    :param symbol_table: dictionary that maps symbol names to their values
    :return: ``np.isfinite(...)`` checks joined with ``' & '``; when no name is found, the
             visitor's ``constants_only`` flag instead
    """
    collector = _NameVisitor(symbol_table)
    collector.visit(node)
    finite_checks = [f'np.isfinite({operand})' for operand in collector.names]
    if finite_checks:
        return ' & '.join(finite_checks)
    return collector.constants_only


class _MinMaxCallable:
    """
    Code generator shared by the ``min()`` and ``max()`` functions.

    :param function_name: name of the function to emit in the generated code
    """

    def __init__(self, function_name: str):
        self._function_name = function_name

    @property
    def name(self):
        """Name of the function emitted in the generated code."""
        return self._function_name

    def generate_code(self, args, symbol_table):
        """
        Validate ``args`` and generate the function call.

        :param args: AST nodes of the call arguments
        :param symbol_table: dictionary that maps symbol names to their values
        :return: generated call as string, formatted as ``<name>( <arg>, <arg> )``
        :raises SyntaxError: if the arguments do not match an allowed invocation form
        """
        self._validate_args(args)
        rendered_args = [_ast_gen(arg, symbol_table) for arg in args]
        joined_args = ', '.join(rendered_args)
        return f'{self.name}( {joined_args} )'

    @staticmethod
    def _validate_args(args):
        """
        Check the invocation syntax of min() / max(). Two forms are accepted:
        two separate arguments, e.g. min(1,2), or one list argument holding exactly
        two values, e.g. min([1,2]).

        :raises SyntaxError: for any other form
        """
        arg_count = len(args)
        if arg_count == 2:
            return
        if arg_count != 1:
            raise SyntaxError('min() and max() functions accept either 2 arguments or a single list argument '
                              'with 2 values')
        only_arg = args[0]
        is_pair_list = type(only_arg) is ast.List and len(only_arg.elts) == 2
        if not is_pair_list:
            raise SyntaxError('List argument of min() and max() functions must have exactly 2 values')


class _MinCallable(_MinMaxCallable):
    """Generator for ``min``, emitted as ``np.minimum``."""

    def __init__(self):
        super().__init__(function_name='np.minimum')


class _MaxCallable(_MinMaxCallable):
    """Generator for ``max``, emitted as ``np.maximum``."""

    def __init__(self):
        super().__init__(function_name='np.maximum')


# Function names accepted in expressions, mapped to the object that generates the
# corresponding call. __generate_function_call() raises SyntaxError for any name
# that is not a key of this table.
__call_for_func_id = {
    'min': _MinCallable(),
    'max': _MaxCallable(),
}

# AST operator node types mapped to the text emitted for them in the generated code.
# Binary and comparison operators are padded with a space on both sides. The unary
# entries (USub, UAdd) only have a leading space because __handle_unary_operator()
# appends the operand directly after the operator text.
# Looking up an operator type that is not listed here raises KeyError.
__operators = {
    ast.Add: ' + ',
    ast.Sub: ' - ',
    ast.Mult: ' * ',
    ast.Div: ' / ',
    ast.FloorDiv: ' // ',
    ast.BitOr: ' | ',
    ast.Lt: ' < ',
    ast.Gt: ' > ',
    ast.LtE: ' <= ',
    ast.GtE: ' >= ',
    ast.Eq: ' == ',
    ast.USub: ' -',
    ast.UAdd: ' +',
    ast.BitAnd: ' & ',
    ast.Pow: ' ** '
}


def __handle_if_expr(node, symbol_table):
    # A ternary "a if condition else b" is emitted as two nested np.where() calls:
    #
    #     np.where(<all operands of condition are finite>, np.where(condition, a, b), np.nan)
    #
    # The inner np.where() is the ternary itself. The outer one exists because np.where()
    # treats NaN as False, so on its own the inner call cannot tell apart
    #   - a condition that is genuinely False, from
    #   - a condition that is False only because one of its operands is inf, -inf or NaN.
    #
    # With the outer guard, the vectorized expression yields NaN whenever an operand of the
    # condition is not a finite number, which is the behavior expected when computing
    # sample-level metrics.

    condition = _ast_gen(node.test, symbol_table)
    finite_guard = _generate_all_finite_test(node.test, symbol_table)
    value_if_true = _ast_gen(node.body, symbol_table)
    value_if_false = _ast_gen(node.orelse, symbol_table)
    ternary = f'np.where({condition}, {value_if_true}, {value_if_false})'
    return f'np.where({finite_guard}, {ternary}, np.nan )'


def __handle_ruby_call(node, symbol_table):
    """
    Generate code for an attribute access written as ``value.function``.

    The attribute name is used as the function name and the object it is accessed on becomes
    the function's single argument.

    :param node: ``ast.Attribute`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated call as string
    """
    receiver_as_args = [node.value]
    return __generate_function_call(node.attr, receiver_as_args, symbol_table)


def __handle_call(node, symbol_table):
    """
    Generate code for a function call written as ``function(args)``.

    :param node: ``ast.Call`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated call as string
    :raises SyntaxError: if the called object is not a plain name (e.g. ``a.b(...)``), or if
                         the function is not supported
    """
    callee = node.func
    # Only plain names carry an ``id``; reject any other call target explicitly instead of
    # failing with an AttributeError on ``callee.id``.
    if not isinstance(callee, ast.Name):
        raise SyntaxError(f'Unsupported function call target: {type(callee).__name__}')
    return __generate_function_call(callee.id, node.args, symbol_table)


def __generate_function_call(function_name, function_args, symbol_table):
    """
    Generate the call for ``function_name`` using the matching entry of the function table.

    :param function_name: name of the function as written in the expression
    :param function_args: AST nodes of the arguments
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated call as string
    :raises SyntaxError: if ``function_name`` is not in the function table
    """
    generator = __call_for_func_id.get(function_name)
    if generator is None:
        raise SyntaxError(f'Unsupported function: {function_name}')
    return generator.generate_code(function_args, symbol_table)


def __handle_list(node, symbol_table):
    """
    Generate code for each element of a list literal and join the results with ``', '``.

    Note that no surrounding brackets are emitted.

    :param node: ``ast.List`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: comma-separated generated elements as string
    """
    rendered_elements = [_ast_gen(element, symbol_table) for element in node.elts]
    return ', '.join(rendered_elements)


def __handle_name(node, symbol_table):
    """
    Generate code for a name.

    * A name missing from the symbol table is emitted as is.
    * A name whose symbol-table value is a number is replaced by that number.
    * Any other symbol-table value is emitted as a ``df['<value>']`` column lookup.

    :param node: ``ast.Name`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    """
    if node.id not in symbol_table:
        return str(node.id)

    constant_val = symbol_table[node.id]
    if not is_number(constant_val):
        return f'df[\'{constant_val}\']'
    if isinstance(constant_val, str):
        # is_number() accepts strings with commas (e.g. '1,000'). Emitting the comma
        # verbatim would change the meaning of the generated code, so drop it here too.
        return constant_val.replace(',', '')
    return str(constant_val)


def __handle_compare(node, symbol_table):
    """
    Generate code for a comparison.

    A single comparison ``a < b`` is emitted as ``a < b``. A chained comparison such as
    ``a < b < c`` is emitted as ``((a < b) & (b < c))`` so that none of the comparisons is
    dropped from the generated code.

    :param node: ``ast.Compare`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    :raises KeyError: if a comparison operator is not in the operators table
    """
    # TODO: write unit tests for this
    left_text = _ast_gen(node.left, symbol_table)
    comparisons = []
    for operator_node, comparator in zip(node.ops, node.comparators):
        operator_text = __operators[type(operator_node)]
        right_text = _ast_gen(comparator, symbol_table)
        comparisons.append(left_text + operator_text + right_text)
        # The right operand of this comparison is the left operand of the next one
        left_text = right_text

    if len(comparisons) == 1:
        return comparisons[0]
    return '(' + ' & '.join(f'({comparison})' for comparison in comparisons) + ')'


def __handle_subscript(node, symbol_table):
    """
    Generate code for a subscript expression ``value[index]``.

    :param node: ``ast.Subscript`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    """
    index_node = node.slice
    if isinstance(index_node, ast.Index):
        # Python 3.8 and below wrap the index in ast.Index; Python 3.9+ store it directly
        index_node = index_node.value
    index = _ast_gen(index_node, symbol_table)
    subscripted = _ast_gen(node.value, symbol_table)
    return f'{subscripted}[{index}]'


def __handle_slice(node, symbol_table):
    """
    Generate code for a slice ``lower:upper`` or ``lower:upper:step``.

    Missing bounds are emitted as empty text. The step is only emitted when the slice has
    one, so slices without a step produce ``lower:upper``.

    :param node: ``ast.Slice`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    """
    def render(bound):
        return '' if bound is None else _ast_gen(bound, symbol_table)

    generated = f'{render(node.lower)}:{render(node.upper)}'
    step = getattr(node, 'step', None)
    if step is not None:
        generated += f':{render(step)}'
    return generated


def __handle_binary_operator(node, symbol_table):
    """
    Generate code for a binary operation, wrapped in parentheses.

    Bitwise operations are delegated to ``BitwiseOperaton``; every other operation is emitted
    as ``(<left><operator><right>)``.

    :param node: ``ast.BinOp`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    """
    if not BitwiseOperaton.is_bitwise(type(node.op)):
        left_text = _ast_gen(node.left, symbol_table)
        operator_text = _get_operator(node)
        right_text = _ast_gen(node.right, symbol_table)
        return f'({left_text}{operator_text}{right_text})'
    return BitwiseOperaton(node, symbol_table).operation


def _get_operator(node):
    """
    Look up the text emitted for the operator of ``node``.

    :param node: AST node with an ``op`` attribute
    :return: operator text from the operators table
    """
    operator_type = type(node.op)
    return __operators[operator_type]


class BitwiseOperaton:
    """
    Generates the code of a bitwise ``&`` / ``|`` operation.

    Both operands are cast to integers before the operator is applied. The generated code
    is computed once on construction and exposed through the ``operation`` property.

    :param node: ``ast.BinOp`` node whose operator is ``ast.BitAnd`` or ``ast.BitOr``
    :param symbol_table: dictionary that maps symbol names to their values
    """
    # Substrings searched for in the generated operand text
    DF = 'df'        # marks an operand as vectorized
    INT = 'int'      # marks a scalar operand as already cast
    INT64 = 'Int64'  # marks a vectorized operand as already cast

    def __init__(self, node, symbol_table):
        self.__node = node
        self.__symbol_table = symbol_table
        self.__operation = self.__get_operation()

    @property
    def operation(self):
        """Generated code of the bitwise operation, wrapped in parentheses."""
        return self.__operation

    @staticmethod
    def is_bitwise(op_type):
        """Return True if ``op_type`` is ``ast.BitAnd`` or ``ast.BitOr``."""
        return op_type in (ast.BitAnd, ast.BitOr)

    @staticmethod
    def __is_vectorized(name):
        return BitwiseOperaton.DF in name

    @staticmethod
    def __to_int(name):
        # Bitwise operations can only be performed on integers.
        # If the operand doesn't already include integer casting, add it.
        if BitwiseOperaton.__is_vectorized(name):
            # A vectorized operand that is already cast (e.g. the result of a nested bitwise
            # operation) must be left alone: wrapping it in int() is not valid for a column.
            if BitwiseOperaton.INT64 in name:
                return name
            return f"{name}.astype('{BitwiseOperaton.INT64}')"
        if BitwiseOperaton.INT in name:
            return name
        return f'int({name})'

    def __get_operation(self):
        operator = _get_operator(self.__node)
        left_operand = _ast_gen(self.__node.left, self.__symbol_table)
        right_operand = _ast_gen(self.__node.right, self.__symbol_table)
        return f'({self.__to_int(left_operand)}{operator}{self.__to_int(right_operand)})'


def __handle_unary_operator(node, symbol_table):
    """
    Generate code for a unary operation: the operator text followed directly by the operand.

    :param node: ``ast.UnaryOp`` node
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    """
    operator_text = _get_operator(node)
    operand_text = _ast_gen(node.operand, symbol_table)
    return f'{operator_text}{operand_text}'


# AST node types mapped to the function that generates code for them. Every function takes
# (node, symbol_table) and returns the generated code as string. _ast_gen() dispatches on
# this table, so node types missing from it raise KeyError there.
__formulas = {
    # ``value`` is used rather than the ``n`` alias, which was removed in Python 3.14
    ast.Constant: lambda node, symbol_table: str(node.value),
    ast.Name: __handle_name,
    ast.List: __handle_list,
    ast.BinOp: __handle_binary_operator,
    ast.UnaryOp: __handle_unary_operator,
    ast.IfExp: __handle_if_expr,
    ast.Call: __handle_call,
    ast.Attribute: __handle_ruby_call,
    ast.Compare: __handle_compare,
    ast.Subscript: __handle_subscript,
    ast.Slice: __handle_slice,
}

# Numeric literals are parsed as ast.Num only before Python 3.8; from 3.8 on they are
# ast.Constant. ast.Num no longer exists in Python 3.14, so it must not be referenced
# unconditionally or the module fails to import there.
if sys.version_info < (3, 8):
    __formulas[ast.Num] = lambda node, symbol_table: str(node.n)


def _ast_gen(node, symbol_table) -> str:
    """
    Generate code for ``node`` by dispatching on its exact type.

    :param node: AST node to convert
    :param symbol_table: dictionary that maps symbol names to their values
    :return: generated code as string
    :raises KeyError: if the node type has no entry in the formulas table
    """
    generate = __formulas[type(node)]
    return generate(node, symbol_table)
