#!/usr/bin/python
# -*- coding: utf-8 -*-
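"""

harktool/app/calctfgeo.py

Sub-command that computes a geometric (free-field) transfer function from a
source (TSP list) position file and a microphone position file, and writes it
to a transfer-function zip file.

"""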

import io
import logging
import types

import numpy

from .. import utils
from ..libharkio3 import (Config, HarkXML, HarkXMLParser, Neighbors, Positions,
                          TransferFunction, TransferFunctionParser)
from ..libharkio3.defs import CoordinateSystem, UseTagType  # POSITION_INFO,
from .defs import SOUND_SPEED, LogLevel, TFType
from .error import HarkToolError
from .utils import (StoreLogLevelAction, StoreTFTypeAction,
                    select_output_channels)

# Default values of the optional command-line arguments declared in
# setup_parser(); read as attributes, e.g. ``_default_values.direct_length``.
_default_values = types.SimpleNamespace(**{
    'output_type': TFType.LOC_SEP,      # --output-type/-t
    'direct_length': 32,                # --direct-length/-d
    'normalize_source': False,          # --normalize-source/-S
    'normalize_microphone': False,      # --normalize-microphone/-M
    'normalize_frequency': False,       # --normalize-frequency/-F
    'reset_microphone': False,          # --reset-microphone/-R
    'log_level': LogLevel.Error,        # --log-level
})


def str_to_bool(value) -> bool:
    """Convert a command-line token such as ``True``, ``false``, ``1`` or ``0`` to a bool.

    Used as the argparse ``type=`` converter for the ``--normalize-*`` options.
    ``type=bool`` cannot be used there: ``bool('False')`` and ``bool('0')`` are
    both ``True`` because every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised boolean token.
    """
    import argparse

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    # '' maps to False, the same result bool('') gives.
    if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        f'invalid boolean value: {value!r} (expected True|False)')


def setup_parser(parser):
    """Register the calctfgeo command-line options on *parser*.

    Also installs :func:`main` as the ``handler`` default, so the caller can
    dispatch to it after parsing.
    """
    parser.add_argument('--tsp-list', '--tlist', '-i', metavar='PATH', required=True,
                        type=str, dest='tsp_list',
                        help='TSP list file  (eg. tsp.xml)')
    parser.add_argument('--microphone-list', '--mlist', '-m', metavar='PATH', required=True,
                        type=str, dest='microphone_list',
                        help='Microphone list file  (eg. microphone.xml)')
    parser.add_argument('--output-type', '--otype', '-t', metavar='|'.join(e.name for e in TFType), required=False, default=_default_values.output_type,
                        type=TFType, action=StoreTFTypeAction, dest='output_type',
                        help='Transfer function type  LOC=Localization,SEP=Separation,LOC_SEP=LOC&SEP (default=%(default)s)')
    parser.add_argument('--output-file', '--ofile', '-o', metavar='PATH', required=True,
                        type=str, dest='output_file',
                        help='Output transfer function zip file path  (eg. /path/to/transfer.zip)')
    parser.add_argument('--direct-length', '--dlen', '-d', metavar='NUM', required=False, default=_default_values.direct_length,
                        type=int, dest='direct_length',
                        help='Direct sound length')
    # The normalize flags take an explicit value (True|False); str_to_bool is
    # used because type=bool would turn 'False'/'0' into True.
    parser.add_argument('--normalize-source', '--snorm', '-S', metavar='True|False', required=False,
                        type=str_to_bool, dest='normalize_source',
                        default=_default_values.normalize_source, help='Normalize by source')
    parser.add_argument('--normalize-microphone', '--mnorm', '-M', metavar='True|False', required=False, default=_default_values.normalize_microphone,
                        type=str_to_bool,
                        help='Normalize by microphone', dest='normalize_microphone')
    parser.add_argument('--normalize-frequency', '--fnorm', '-F', metavar='True|False', required=False, default=_default_values.normalize_frequency,
                        type=str_to_bool, dest='normalize_frequency',
                        help='Normalize by frequency')
    parser.add_argument('--reset-microphone', '--mreset', '-R', default=_default_values.reset_microphone,
                        action='store_true', dest='reset_microphone',
                        help='Remove unused channels from microphones.xml and reset their ids. (for HARKTOOL4 compatible)')
    parser.add_argument('--log-level', '--llevel', metavar='{E|W|I|D}', required=False, default=_default_values.log_level,
                        type=LogLevel, action=StoreLogLevelAction, dest='log_level',
                        help='Log information level. (default=%(default)s)')

    parser.set_defaults(handler=main)


def main(args):
    """Handle the calctfgeo sub-command.

    Initializes logging from *args*, rejects a negative ``direct_length`` and
    then forwards every parsed option, plus the logger, to
    calculate_geometric_tf().

    Raises:
        HarkToolError: If ``args.direct_length`` is negative.
    """
    logger = utils.initialize_logger(args)

    # Validate before any input file is read.
    direct_length = args.direct_length
    if direct_length < 0:
        raise HarkToolError(
            'Unexpected value for --direct-length/-d.  Positive value is expected.')

    options = vars(args)
    calculate_geometric_tf(logger=logger, **options)


def calculate_geometric_tf(
        tsp_list: str,
        microphone_list: str,
        output_type: TFType,
        output_file: str,
        direct_length: int,
        normalize_source: bool,
        normalize_microphone: bool,
        normalize_frequency: bool,
        reset_microphone: bool,
        logger: logging.Logger,
        **kwargs) -> None:
    """Build a geometric (free-field) transfer function and save it as a zip file.

    For every source position in *tsp_list* and every selected microphone in
    *microphone_list*, the propagation delay in samples is computed as
    ``sampling_rate / SOUND_SPEED * distance``. Here ``distance`` is the
    Euclidean distance between the Cartesian coordinates. The delay is turned
    into the phase term ``exp(-j * 2 * pi * k / n_fft * (delay - offset))`` for
    the bins ``k = 0 .. n_fft // 2``.

    Args:
        tsp_list: Path of the source (TSP list) HARK XML file. It must contain
            ``Config`` and ``Positions`` and may contain ``Neighbors``.
        microphone_list: Path of the microphone HARK XML file. It must contain
            ``Positions``.
        output_type: ``LOC`` stores the phase terms, ``SEP`` stores their
            complex conjugate, and ``LOC_SEP`` stores both.
        output_file: Destination zip file path.
        direct_length: Direct sound length in samples. When the largest delay
            plus this length exceeds ``n_fft``, all delays are shifted back by
            the excess.
        normalize_source, normalize_microphone, normalize_frequency: Only used
            to decide whether to log a warning. The SEP summary always lists
            every normalization as enabled.
        reset_microphone: Only logged; not applied by this function.
        logger: Logger for progress and diagnostics.
        **kwargs: Remaining parsed command-line options; ignored.

    Raises:
        HarkToolError: If an XML file cannot be read, a required section is
            missing, or no output channel is selected.
    """

    logger.info("""harktool_calctfgeo    tsp_file      [{}]
                    , microphone_file      [{}]
                    , output_type   [{}]
                    , output_file   [{}]
                    , direct_length [{}]
                    , normalize_source      [{}]
                    , normalize_microphone      [{}]
                    , normalize_frequency     [{}]
                    , reset_microphone     [{}]
""".format(
        tsp_list, microphone_list, output_type, output_file, direct_length, normalize_source, normalize_microphone, normalize_frequency, reset_microphone))

    # Human-readable summary, emitted to the log at the end.
    message = io.StringIO()

    # read XML files
    src_file_xml = _read_hark_xml(tsp_list, 'tsp list', logger)
    mic_file_xml = _read_hark_xml(microphone_list, 'microphone list', logger)

    # retrieve data
    source_neighbors: Neighbors = src_file_xml.neighbors  # may be None

    source_config: Config = src_file_xml.config
    if source_config is None:
        raise HarkToolError('Config not found in source xml')

    source_positions: Positions = src_file_xml.positions
    if source_positions is None:
        raise HarkToolError('Positions not found in source xml')

    microphone_positions: Positions = mic_file_xml.positions
    if microphone_positions is None:
        raise HarkToolError('Positions not found in microphone xml')

    # Materialize once. The indices are used twice: to pick microphone
    # coordinates and to place results on the TF channel axis.
    output_channel_indices = list(select_output_channels(source_positions, microphone_positions))
    if not output_channel_indices:
        raise HarkToolError('No output channel selected from microphone xml')

    # NOTE(review): reset_microphone is logged above but not applied here.
    # microphone_positions is written out unchanged; confirm this is intended.

    sampling_rate = source_config.sampling_rate
    fft_bin_size = source_config.n_fft
    fft_half_size = fft_bin_size // 2 + 1

    microphone_positions_matrix = numpy.array([
        microphone_positions[index].as_coordinate(CoordinateSystem.Cartesian)
        for index in output_channel_indices
    ])

    # Sort in place so the TF rows line up with the positions passed to
    # TransferFunction below.
    source_positions.sort_positions_by_id()
    source_positions_matrix = numpy.array([
        position.as_coordinate(CoordinateSystem.Cartesian) for position in source_positions.positions
    ])

    # The channel axis spans every microphone handed to TransferFunction.
    # Only the selected channels are filled; the others stay zero.
    # Assigning the selected rows to the full axis would fail to broadcast
    # whenever only a subset of channels is selected.
    tf_vector = numpy.zeros(
        (len(source_positions), len(microphone_positions), fft_half_size), dtype=numpy.complex64)

    for src_index, src_position in enumerate(source_positions_matrix):
        logger.info('progress {}/{}'.format(src_index + 1, len(source_positions)))
        tf_vector[src_index, output_channel_indices, :] = _geometric_steering_vector(
            src_position, microphone_positions_matrix, sampling_rate, fft_bin_size, direct_length)

    if output_type in (TFType.LOC, TFType.LOC_SEP):
        loc_tfs = tf_vector
    else:
        loc_tfs = None

    if output_type in (TFType.SEP, TFType.LOC_SEP):
        print('Normalization Type', file=message)
        print('- Source      [{}]'.format(True), file=message)
        print('- Microphone  [{}]'.format(True), file=message)
        print('- Frequency   [{}]'.format(True), file=message)
        if not (normalize_source and normalize_microphone and normalize_frequency):
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning(
                'ignored --normalize*=0 option. calctfgeo always generates normalized value.')

        # The separation TF is the complex conjugate of the localization TF.
        sep_tfs = tf_vector.conj()
    else:
        sep_tfs = None

    print('Direct sound Length {}'.format(direct_length), file=message)
    logger.info(message.getvalue())

    # setup worker instance
    tf = TransferFunction(positions=source_positions, microphones=microphone_positions,
                          config=source_config, neighbors=source_neighbors, loc_tfs=loc_tfs, sep_tfs=sep_tfs)
    TransferFunctionParser.as_zipfile(tf, output_file)


def _read_hark_xml(path: str, description: str, logger: logging.Logger) -> 'HarkXML':
    """Parse the HARK XML file at *path*.

    *description* names the file kind in the log and error messages
    (e.g. ``'tsp list'``).

    Raises:
        HarkToolError: If parsing fails; the original exception is chained.
    """
    try:
        document = HarkXMLParser.from_file(path)
    except Exception as ex:
        # Exception rather than BaseException, so KeyboardInterrupt and
        # SystemExit propagate unchanged.
        raise HarkToolError(f'Failed to read {description} file: {path}') from ex
    logger.info(f'Read {description} file: {path}')
    return document


def _geometric_steering_vector(source_position, microphone_matrix, sampling_rate, fft_bin_size, direct_length):
    """Return the free-field phase terms of one source for each microphone row.

    *microphone_matrix* holds one coordinate row per microphone. The result has
    shape ``(len(microphone_matrix), fft_bin_size // 2 + 1)``.
    """
    fft_half_size = fft_bin_size // 2 + 1

    differences = microphone_matrix - source_position
    distances = numpy.sqrt(
        numpy.sum(differences * differences, axis=1).astype(numpy.float32))
    delays = sampling_rate / SOUND_SPEED * distances  # in samples

    # If the farthest microphone's delay plus the direct sound length exceeds
    # the FFT length, shift all delays back by the excess.
    shift = max(0, numpy.max(delays) + direct_length - fft_bin_size)

    theta = 2 * numpy.pi * \
        (numpy.arange(fft_half_size) / fft_bin_size) * \
        (delays - shift)[:, None]
    return numpy.exp(-theta * 1.j)

# end of file
