import argparse
import enum
import sys
import types
import typing

import numpy
import pandas

from .. import utils
from ..libharkio3 import (Config, Positions, TransferFunction,
                          TransferFunctionParser)
from .defs import LogLevel, TFType
from .error import HarkToolError
from .utils import StoreLogLevelAction

class ZAxisType(enum.Enum):
    """Quantity plotted on the z axis of the exported coordinates.

    The values are the short identifiers of each quantity. ``Real`` is the
    real part of the sample; the coordinate code in this module selects it
    for inverse-FFT (time-domain) output.
    """

    Amplitude = 'amp'
    dB = 'db'
    Phase = 'phase'
    Real = 'real'

# Module-wide fallback settings.
# tf_type, z_axis_type, enable_ifft and log_level are the argparse defaults
# used by parse_args(); n_fft and sampling_rate are substituted in
# coordinate_from_tf() when the transfer function config carries no usable
# value while scaling the per-source offsets.
_default_values = types.SimpleNamespace(
    tf_type=TFType.LOC,
    z_axis_type=ZAxisType.Amplitude,
    enable_ifft=False,
    log_level=LogLevel.Error,
    n_fft=512,
    sampling_rate=16000,
)


def parse_args(args):
    """Parse the command-line arguments of the coordinate exporter.

    ``--tf-type`` and ``--z-axis-type`` are converted to their enum members
    while parsing, so the values compare equal to ``TFType.*`` /
    ``ZAxisType.*`` in the rest of the module (a plain ``str`` never does).

    Args:
        args: Sequence of argument strings, without the program name.

    Returns:
        The ``argparse.Namespace`` holding ``tf_type``, ``z_axis_type``,
        ``enable_ifft``, ``log_level``, ``input_file`` and ``output_file``.
    """

    def to_tf_type(text):
        # Only the two members this module actually handles are accepted.
        name = text.upper()
        if name not in ('LOC', 'SEP'):
            raise argparse.ArgumentTypeError(
                f'invalid transfer function type: {text!r} (choose from LOC, SEP)')
        return getattr(TFType, name)

    def to_z_axis_type(text):
        # Accept either the short value ('amp') or the member name
        # ('Amplitude'), case-insensitively.
        lowered = text.lower()
        for member in ZAxisType:
            if lowered in (str(member.value).lower(), member.name.lower()):
                return member
        choices = ', '.join(str(member.value) for member in ZAxisType)
        raise argparse.ArgumentTypeError(
            f'invalid z-axis type: {text!r} (choose from {choices})')

    parser = argparse.ArgumentParser(
        description='Coordinate from transfer function.')

    parser.add_argument('--tf-type', '--type', '-t', metavar='LOC|SEP', required=False,
                        type=to_tf_type, dest='tf_type',
                        default=_default_values.tf_type,
                        help='Type of Transfer function (default=%(default)s)')
    parser.add_argument('--z-axis-type', '--zaxis', '-z', metavar='NAME', required=False,
                        type=to_z_axis_type, dest='z_axis_type',
                        default=_default_values.z_axis_type,
                        help='Z_axis type (default=%(default)s)')
    parser.add_argument('--enable-ifft', '--ifft', '-i', required=False,
                        dest='enable_ifft', action='store_true',
                        default=_default_values.enable_ifft,
                        help='Enable ifft (z_axis=real)')
    parser.add_argument('--log-level', '--llevel', metavar='{E|W|I|D}', required=False,
                        type=LogLevel, action=StoreLogLevelAction, dest='log_level',
                        default=_default_values.log_level,
                        help='Log information level. (default=%(default)s)', )
    # Positionals are declared through ``dest`` only; argparse derives the
    # positional name from it.
    parser.add_argument(metavar='INFILE',
                        type=str, dest='input_file',
                        help='Input zip file')
    parser.add_argument(metavar='OUTFILE',
                        type=str, dest='output_file',
                        help='Output csv file')
    return parser.parse_args(args)


def main(args=None):
    """Command-line entry point: parse arguments, set up logging, convert.

    Args:
        args: Argument strings without the program name. When omitted,
            ``sys.argv[1:]`` is read at call time (a default written as
            ``sys.argv[1:]`` would be frozen when the module is imported).
    """
    if args is None:
        args = sys.argv[1:]

    namespace = parse_args(args)

    logger = utils.initialize_logger(namespace)

    # Every parsed option is forwarded by keyword, together with the logger.
    zip_to_coordinate(**vars(namespace), logger=logger)


def zip_to_coordinate(
    input_file: str,
    output_file: str,
    tf_type: TFType = _default_values.tf_type,
    z_axis_type: ZAxisType = _default_values.z_axis_type,
    enable_ifft: bool = _default_values.enable_ifft,
    log_level: LogLevel = _default_values.log_level,
    logger: typing.Any = None,
) -> None:
    """Load a transfer function file and export its coordinate table.

    Args:
        input_file: Path handed to ``TransferFunctionParser.from_file``.
        output_file: Destination handed to ``write_coordinate``.
        tf_type: Which matrices of the transfer function to export.
        z_axis_type: Quantity computed for the z column.
        enable_ifft: Export time-domain data obtained by inverse FFT.
        log_level: Accepted because ``main`` forwards the whole parsed
            namespace; not used here.
        logger: Accepted because ``main`` forwards it; not used here.
    """
    # load tf
    tf: TransferFunction = TransferFunctionParser.from_file(input_file)

    # generate coord (keywords keep the call independent of parameter order)
    coord = coordinate_from_tf(
        tf,
        z_axis_type=z_axis_type,
        tf_type=tf_type,
        enable_ifft=enable_ifft,
    )

    # write coord
    write_coordinate(output_file, coord)


def write_coordinate(output_file: str | typing.IO, coord: typing.Any) -> None:
    pass


def coordinate_from_tf(tf: TransferFunction, z_axis_type: ZAxisType, tf_type: TFType, enable_ifft: bool) -> typing.Any:
    """Build the coordinate table for one transfer function.

    Python port of the C routine ``writeCoordinateFromTF``.

    Args:
        tf: Transfer function providing ``positions``, ``config`` and the
            ``loc_tfs`` / ``sep_tfs`` matrices.
        z_axis_type: Quantity for the z column. Ignored when ``enable_ifft``
            is set, in which case the real part is always used.
        tf_type: Selects ``tf.loc_tfs`` (LOC) or ``tf.sep_tfs`` (SEP).
        enable_ifft: When true the matrices are inverse-transformed and the
            x axis is scaled by ``1 / sampling_rate``; otherwise the x axis
            is scaled by ``sampling_rate / n_fft``.

    Returns:
        Whatever ``calculate_coordinate`` returns for the selected data.

    Raises:
        HarkToolError: If ``tf_type`` is not LOC/SEP or no matrix is present.
    """
    positions: Positions = tf.positions
    config: Config = tf.config
    y_weight = 0.2

    if enable_ifft:
        x_weight = 1.0 / config.sampling_rate
        x_num = config.n_fft
    else:
        x_weight = config.sampling_rate / config.n_fft
        x_num = config.n_fft // 2 + 1

    if tf_type == TFType.LOC:
        matrix_data = tf.loc_tfs
    elif tf_type == TFType.SEP:
        matrix_data = tf.sep_tfs
    else:
        raise HarkToolError(f'Unexpected tf_type: {tf_type}')

    # Explicit emptiness test: truth-testing a multi-element ndarray raises.
    if matrix_data is None or len(matrix_data) == 0:
        raise HarkToolError('No matrix found in tf data')

    # calculate_coordinate only takes len() of its ``microphones`` argument.
    # The C reference passes the row count of the first matrix, so a range
    # of that length stands in for the microphone list here.
    microphones = range(len(matrix_data[0]))

    if enable_ifft:
        sampling_rate = config.sampling_rate if config.sampling_rate > 0 else _default_values.sampling_rate
        # Offsets divided by the sampling rate shift the x axis.
        # NOTE(review): treating them as an x (time) offset follows the
        # x_offset parameter of the C reference - confirm against the tool.
        x_offset = numpy.asarray(get_offset(tf_type=tf_type, tf=tf), dtype=float) / sampling_rate

        # Separation matrices are replaced by their complex conjugate
        # before the inverse transform.
        freq_data = numpy.conj(matrix_data) if tf_type == TFType.SEP else matrix_data
        time_data = calculate_ifft_matrix(freq_data)

        return calculate_coordinate(
            z_axis_type=ZAxisType.Real,
            matrix_data=time_data,
            sources=positions,
            microphones=microphones,
            x_num=x_num,
            x_weight=x_weight,
            x_offset=x_offset,
            y_weight=y_weight,
            z_offset=None,
        )

    if z_axis_type in (ZAxisType.Real, ZAxisType.Phase):
        n_fft = config.n_fft if config.n_fft else _default_values.n_fft
        # Offsets become a per-bin phase step (2*pi/n_fft per unit offset).
        z_offset = numpy.asarray(get_offset(tf_type=tf_type, tf=tf), dtype=float) * (2.0 * numpy.pi / n_fft)
    else:
        z_offset = None

    return calculate_coordinate(
        z_axis_type=z_axis_type,
        matrix_data=matrix_data,
        sources=positions,
        microphones=microphones,
        x_num=x_num,
        x_weight=x_weight,
        x_offset=None,
        y_weight=y_weight,
        z_offset=z_offset,
    )


def calculate_z(z_axis_type: ZAxisType, z: complex) -> float:
    """Reduce one complex sample to the value plotted on the z axis.

    Args:
        z_axis_type: Requested quantity.
        z: The sample.

    Returns:
        ``z.real`` for Real, ``|z|`` for Amplitude, ``10 * log10(|z|)`` for
        dB, and the angle of ``z`` in degrees for Phase.

    Raises:
        HarkToolError: If ``z_axis_type`` is none of the supported members
            (previously this fell through and returned ``None``).
    """
    if z_axis_type == ZAxisType.Real:
        return z.real

    if z_axis_type == ZAxisType.Amplitude:
        return numpy.abs(z)

    if z_axis_type == ZAxisType.dB:
        return numpy.log10(numpy.abs(z)) * 10

    if z_axis_type == ZAxisType.Phase:
        return numpy.angle(z, deg=True)

    raise HarkToolError(f'Unexpected z_axis_type: {z_axis_type}')


def calculate_coordinate(
    z_axis_type: ZAxisType, matrix_data: numpy.ndarray,
    sources: Positions, microphones: Positions,
    x_num: int,
    x_weight: float, x_offset: numpy.ndarray, y_weight: float, z_offset: numpy.ndarray
) -> pandas.DataFrame:
    """Tabulate ``(id, mic, x, y, z)`` rows for every source/mic/bin.

    Python port of the C routine ``writeCoordinateInternal``; the table is
    returned instead of being written to a file.

    Args:
        z_axis_type: Quantity computed by ``calculate_z`` for the z column.
        matrix_data: One matrix per source, indexed ``[mic_row][x_index]``.
        sources: Source positions; ``sources[i]`` labels ``matrix_data[i]``.
        microphones: Only ``len(microphones)`` is used, as the microphone
            count when the first source does not list its channels.
        x_num: Number of x samples taken from every matrix row.
        x_weight: Scale applied to the x index.
        x_offset: Per-source value added to x, or ``None`` for no offset.
        y_weight: Scale applied to the source index to obtain y.
        z_offset: Per-source phase step; each sample is multiplied by
            ``exp(-1j * z_offset[src] * x_index)``. ``None`` disables it.

    Returns:
        DataFrame with the columns ``id, mic, x, y, z``.
    """
    # NOTE(review): the C reference tests ``channelsUse == 1``. The value 1
    # is taken from there because no USE_TAG constant is defined in this
    # module - confirm against the libharkio3 Python API.
    channels_use_tag = 1
    if len(sources) > 0 and sources[0].channels_use == channels_use_tag:
        microphone_position_ids = list(sources[0].channels)
    else:
        microphone_position_ids = list(range(len(microphones)))

    bin_indices = numpy.arange(x_num)

    rows = []
    for s_index, s_matrix in enumerate(matrix_data):
        # presumably each position exposes ``id`` like the C struct - verify.
        source_id = sources[s_index].id
        s_x_offset = 0.0 if x_offset is None else x_offset[s_index]
        y = s_index * y_weight

        # The phase rotation depends only on source and bin: compute it once
        # per source rather than once per sample.
        if z_offset is None:
            rotation = None
        else:
            rotation = numpy.exp(-1j * z_offset[s_index] * bin_indices)

        # The matrix row is the running microphone index; the channel id is
        # only the label written to the ``mic`` column.
        for m_row, m_id in enumerate(microphone_position_ids):
            for x_index in range(x_num):
                x = x_index * x_weight + s_x_offset

                z = s_matrix[m_row][x_index]
                if rotation is not None:
                    z = z * rotation[x_index]

                rows.append((source_id, m_id, x, y, calculate_z(z_axis_type, z)))

    return pandas.DataFrame(data=rows, columns=['id', 'mic', 'x', 'y', 'z'])


def get_offset(tf_type: TFType, tf: TransferFunction) -> numpy.ndarray:
    """Collect the per-position ``offset`` value of a transfer function.

    Args:
        tf_type: LOC reads the ``loc`` entry of each position's
            ``tfcalc_params``, SEP reads the ``sep`` entry.
        tf: Transfer function whose ``positions`` are scanned.

    Returns:
        1-D float array with one offset per position; 0.0 where a position
        has no entry for this type or no ``offset`` key.

    Raises:
        HarkToolError: If ``tf_type`` is neither LOC nor SEP.
    """
    if tf_type == TFType.LOC:
        tf_key = 'loc'
    elif tf_type == TFType.SEP:
        tf_key = 'sep'
    else:
        raise HarkToolError(f'Unexpected tf_type: {tf_type}')

    offsets = []
    for position in tf.positions:
        params = getattr(position.tfcalc_params, tf_key, None) or {}
        # A missing or empty value counts as 0, as in the C reference.
        offsets.append(float(params.get('offset', 0) or 0))

    # Build from a list: numpy.array() on a generator yields a 0-d object
    # array wrapping the generator, not the values.
    return numpy.array(offsets, dtype=float)

#     char* srcPath = "transferFunction/source.xml";
#     char* strSrcXml = NULL;
#     if (as_zip_fread(tf->handle, srcPath, &strSrcXml) != EXIT_SUCCESS) {
#         return EXIT_FAILURE;
#     }
#     // parse and modify source.xml
#     xmlInitParser();
#     xmlDocPtr doc = xmlParseDoc(BAD_CAST strSrcXml);
#     free(strSrcXml);
#     if (doc == NULL) {
#         harkio_Log_printf(harkio_Log_Error, "source.xml parse error");
#         return EXIT_FAILURE;
#     }
#     xmlXPathContextPtr xpathCtx = xmlXPathNewContext(doc);
#     if (xpathCtx == NULL) {
#         harkio_Log_printf(harkio_Log_Error, "xmlXPathNewContext error");
#         xmlFreeDoc(doc);
#         return EXIT_FAILURE;
#     }
#     for (int i = 0; i < tf->poses->size; i++) {
#         char xpathExpr[100];
#         sprintf(xpathExpr, "/hark_xml/positions/position[@id=%d]/calcTF[@type=%s]/@offset", tf->poses->pos[i]->id,
#                 (tf_type == LOC) ? "\"localization\"" : "\"separation\"");
#         xmlXPathObjectPtr xpathObj = xmlXPathEvalExpression(BAD_CAST xpathExpr, xpathCtx);
#         if (xpathObj == NULL) {
#             harkio_Log_printf(harkio_Log_Error, "xmlXPathEvalExpression error");
#             xmlXPathFreeContext(xpathCtx);
#             xmlFreeDoc(doc);
#             return EXIT_FAILURE;
#         }
#         if (xpathObj->nodesetval != NULL && xpathObj->nodesetval->nodeNr > 0) {
#             char* strOffset = (char*) xmlNodeGetContent(xpathObj->nodesetval->nodeTab[0]);
#             offset[i] = (strOffset == NULL) ? 0.0 : atof((char*) strOffset);
#             xmlFree(strOffset);
#         } else {
#             offset[i] = 0.0;
#         }
#         xmlXPathFreeObject(xpathObj);
#     }
#     xmlXPathFreeContext(xpathCtx);
#     xmlFreeDoc(doc);
#     xmlCleanupParser();

#     return EXIT_SUCCESS;
# }


def calculate_ifft_matrix(freq_data: numpy.ndarray, n_fft: typing.Optional[int] = None) -> numpy.ndarray:
    """Inverse real FFT of every row along the last (frequency) axis.

    Each row of half-spectrum bins is transformed independently. All
    leading axes are left untouched, unlike ``irfftn``, which would also
    invert across them.

    Args:
        freq_data: Half-spectrum data with frequency bins on the last axis.
        n_fft: Length of each time-domain row. When omitted (or 0) numpy's
            default of ``2 * (bins - 1)`` is used.

    Returns:
        Real array shaped like ``freq_data`` except for the last axis.
    """
    spectrum = numpy.asarray(freq_data)
    return numpy.fft.irfft(spectrum, n=n_fft or None, axis=-1)

# void calcIFFTMatrixData(harkio_Matrix** freqdata, harkio_Matrix** timedata, int num_src, int num_mic, int num_fft) {
#     int cnt;
#     int src;
#     int mic;
#     int fft;

#     COMPLEX_FLOAT* buff_A;
#     float* float_a;

#     buff_A = (COMPLEX_FLOAT*) malloc((num_fft / 2 + 1) * sizeof(COMPLEX_FLOAT));
#     if (buff_A == NULL) {
#         harkio_Log_printf(harkio_Log_Error, "malloc error");
#         exit(EXIT_FAILURE);
#     }
#     float_a = (float*) malloc(num_fft * sizeof(float));
#     if (float_a == NULL) {
#         harkio_Log_printf(harkio_Log_Error, "malloc error");
#         exit(EXIT_FAILURE);
#     }

#     // Run the inverse FFT
#     for (src = 0; src < num_src; src++) { // f_max=g.Nsrc )
#         for (mic = 0; mic < num_mic; mic++) {
#             for (fft = 0; fft < num_fft / 2 + 1; fft++) {
#                 cnt = fft + mic * (num_fft / 2 + 1);
#                 buff_A[fft] = freqdata[src]->data.complexdata[cnt];
#             }

#             // Run the FFT: the inverse FFT of A[] is stored in a[]
#             // Size of A (conjugate-symmetric complex): [0:Nfft/2] --> size of a (real only): [0:Nfft-1]
#             // harktool_calcFloatIFFT( buff_A , float_a );
#             harktool_FFTImpl_fRealFFTInv(buff_A, float_a, num_fft);

#             for (fft = 0; fft < num_fft; fft++) {
#                 cnt = fft + mic * (num_fft);
#                 timedata[src]->data.complexdata[cnt] = (COMPLEX_FLOAT) float_a[fft];
#             }
#         }
#     }
#     free(buff_A);
#     free(float_a);
# }

# end of file
