#!/usr/bin/python
# -*- coding: utf-8 -*-
"""

harktool/app/calctf.py

"""

import scipy

import typing

import numpy

from .defs import TFType
from .error import HarkToolError
from .workset import SoundData, WorkingDirectory


class IRInfo:
    """Sample window of an impulse response selected for the FFT.

    ``cut_from``/``cut_to`` are the slice start/stop used to copy samples out of
    each impulse response, and ``tf_offset`` records the offset of that window.
    All three values are exposed read-only.
    """

    def __init__(self, tf_offset: int, cut_from: int, cut_to: int):
        super().__init__()
        self._tf_offset, self._cut_from, self._cut_to = tf_offset, cut_from, cut_to
        self._irvec = None  # never read within this class

    @property
    def tf_offset(self) -> int:
        """Offset of the window, as given at construction."""
        return self._tf_offset

    @property
    def cut_from(self) -> int:
        """Slice start of the window (inclusive)."""
        return self._cut_from

    @property
    def cut_to(self) -> int:
        """Slice stop of the window (exclusive)."""
        return self._cut_to


class IRVector:
    """Result container filled by ``IRVectorExtractor``.

    ``loc``/``sep`` receive the IRInfo windows used for localization and
    separation, ``loc_irvec``/``sep_irvec`` the stacked windowed impulse
    responses. Every field starts out as ``None``.
    """

    def __init__(self):
        super().__init__()
        self.loc = self.sep = self.loc_irvec = self.sep_irvec = None


class IRVectorExtractor:
    """Cut per-channel impulse responses into fixed-length vectors ready for the FFT.

    The peak of every impulse response is searched, one common window is derived
    from the latest peak over all responses, and each response is copied into a
    zero-padded buffer of ``fft_window`` samples.
    """

    def __init__(self, output_type: TFType, direct_length: int, reverb_length: int, peak_search: tuple[int, int], fft_window: int,
                 logger: typing.Any = None, message_buffer: typing.IO = None):
        """
        Args:
            output_type: which vectors to build (``TFType.LOC``, ``TFType.SEP`` or ``TFType.LOC_SEP``).
            direct_length: samples kept after the peak in LOC vectors; also sets the window margin.
            reverb_length: stored only; not read by this class.
            peak_search: ``(from, to)`` peak-search bounds. One is subtracted from each
                before slicing; if either result is negative the whole response is searched.
            fft_window: length of every output vector.
            logger: stored only; not read by this class.
            message_buffer: stored only; not read by this class.
        """
        super().__init__()
        self._output_type = output_type
        self._direct_length = direct_length
        self._reverb_length = reverb_length
        self._peak_search_from, self._peak_search_to = peak_search
        self._fft_window = fft_window
        self._logger = logger
        self._message_buffer = message_buffer

    def __call__(self, impulse_responses: SoundData) -> IRVector | list[IRVector]:
        """Build the FFT input vectors for one set of impulse responses.

        Args:
            impulse_responses: iterated once; each item is used as a 1-D sample
                array (it must support ``.size`` and slicing).

        Returns:
            An IRVector. ``loc``/``loc_irvec`` are filled for ``TFType.LOC`` and
            ``TFType.LOC_SEP``, ``sep``/``sep_irvec`` for ``TFType.SEP`` and
            ``TFType.LOC_SEP``; fields of a type that is not requested stay ``None``.

        Raises:
            HarkToolError: if ``impulse_responses`` yields nothing.
        """
        frame_list = list(impulse_responses)
        if not frame_list:
            raise HarkToolError('No impulse response to extract vectors from.')

        peak_list = [self._search_peak(frame, self._peak_search_from - 1, self._peak_search_to - 1)
                     for frame in frame_list]

        # Latest peak *position* over all responses: the common window is placed so that
        # it still contains this peak plus the direct-sound margin. (numpy.argmax(peak_list)
        # would give the index of that response instead of its sample position.)
        farthest_peak = int(max(peak_list))
        minimum_frame_length = min(frame.size for frame in frame_list)

        # Only the requested vectors are stacked: numpy.stack([]) raises ValueError, so
        # stacking an empty list broke the LOC-only and SEP-only modes.
        irvec: IRVector = IRVector()
        if self._output_type in (TFType.LOC, TFType.LOC_SEP):
            irvec.loc = self._irrange_for_fft(farthest_peak, self._direct_length + 1, minimum_frame_length)
            irvec.loc_irvec = numpy.stack([
                self._ir_for_fft(TFType.LOC, peak, irvec.loc, self._direct_length, frame)
                for peak, frame in zip(peak_list, frame_list)])
        if self._output_type in (TFType.SEP, TFType.LOC_SEP):
            irvec.sep = self._irrange_for_fft(farthest_peak, self._direct_length + 1, minimum_frame_length)
            irvec.sep_irvec = numpy.stack([
                self._ir_for_fft(TFType.SEP, peak, irvec.sep, self._direct_length, frame)
                for peak, frame in zip(peak_list, frame_list)])

        return irvec

    def _search_peak(self, irVect: numpy.ndarray, start: int, end: int) -> int:
        """Return the index of the largest absolute sample of ``irVect[start:end]``.

        The index is relative to the start of ``irVect``. If either bound is
        negative, the whole of ``irVect`` is searched.
        """
        if start < 0 or end < 0:
            return numpy.abs(irVect).argmax()
        else:
            return numpy.abs(irVect[start:end]).argmax() + start

    def _irrange_for_fft(self, farthest: int, margin: int, limit: int) -> IRInfo:
        """Choose a window of at most ``fft_window`` samples.

        The start is ``farthest + margin - fft_window``, clamped at 0. The end is
        ``start + fft_window``, clamped at ``limit``. ``tf_offset`` and ``cut_from``
        are both set to the start.
        """
        start = max(0, farthest + margin - self._fft_window)
        end = min(limit, start + self._fft_window)

        return IRInfo(tf_offset=start, cut_from=start, cut_to=end)

    def _ir_for_fft(self, tftype: TFType, peak: int, irinfo: IRInfo, sound_length: int, src_ir_vect: list[float]) -> list[float]:
        """Copy ``src_ir_vect[cut_from:cut_to]`` into a zeroed float32 buffer of ``fft_window`` samples.

        For ``TFType.LOC`` only the peak and the ``sound_length`` samples after it
        are kept; everything later in the buffer is zeroed.
        """
        dest = numpy.zeros(self._fft_window, dtype=numpy.float32)
        dest[:(irinfo.cut_to - irinfo.cut_from)] = src_ir_vect[irinfo.cut_from:irinfo.cut_to]
        if tftype == TFType.LOC:
            # Clamp at 0 so a peak far before the window cannot turn into a negative
            # slice start (which would wrap around and zero the wrong tail).
            keep = max(0, peak - irinfo.cut_from + 1 + sound_length)
            dest[keep:] = 0.

        return dest


class TFExtractor:
    """Turn windowed impulse-response vectors into transfer functions (TFs).

    Localization TFs are plain real FFTs of the vectors. Separation TFs are real
    FFTs of energy-scaled vectors, optionally divided by sums of TF magnitudes.
    """

    def __init__(self, output_type: TFType,
                 normalize_source: bool, normalize_microphone: bool, normalize_frequency: bool,
                 reset_microphone: bool,
                 logger: typing.Any = None, message_buffer: typing.IO = None):
        """
        Args:
            output_type: which TFs to compute (``TFType.LOC``, ``TFType.SEP`` or ``TFType.LOC_SEP``).
            normalize_source: divide separation TFs by magnitude sums over axes 1 and 2.
            normalize_microphone: divide separation TFs by magnitude sums over axes 0 and 2.
            normalize_frequency: divide separation TFs by magnitude sums over axes 0 and 1.
            reset_microphone: stored only; not read by this class.
            logger: stored only; not read by this class.
            message_buffer: stored only; not read by this class.
        """
        super().__init__()
        self._output_type = output_type
        self._normalize_source = normalize_source
        self._normalize_microphone = normalize_microphone
        self._normalize_frequency = normalize_frequency
        self._reset_microphone = reset_microphone
        self._logger = logger
        self._message_buffer = message_buffer

    def __call__(self, irvectors: list[IRVector]) -> tuple[list, list]:
        """Compute the TFs for every IRVector.

        Args:
            irvectors: ``loc_irvec``/``sep_irvec`` of each entry are read as 2-D
                arrays with one row per impulse response.

        Returns:
            ``(loc_tfs, sep_tfs)``: lists with one complex array per IRVector, or
            ``None`` for a TF type that ``output_type`` does not request.
        """
        loc_tfs = None
        if self._output_type in (TFType.LOC, TFType.LOC_SEP):
            loc_tfs = [self._generate_tf_music(irvec.loc_irvec) for irvec in irvectors]

        sep_tfs = None
        if self._output_type in (TFType.SEP, TFType.LOC_SEP):
            sep_tfs = [self._energy_normalized_rfft(irvec.sep_irvec) for irvec in irvectors]
            if sep_tfs and (self._normalize_source or self._normalize_microphone or self._normalize_frequency):
                # _normalize_tf reduces over pairs of axes 0/1/2 and broadcasts back in
                # 3-D, so it needs the TFs of all IRVectors stacked into one array; on a
                # single 2-D TF those reductions fail.
                sep_tfs = list(self._normalize_tf(numpy.stack(sep_tfs)))

        return loc_tfs, sep_tfs

    def _generate_tf_music(self, irvect):
        """Localization TF: real FFT of each row of *irvect*."""
        tfvect = numpy.fft.rfft(irvect)
        return tfvect

    def _energy_normalized_rfft(self, irvect):
        """Divide each row of *irvect* by its energy (sum of squares), then take its real FFT."""
        energy = (irvect * irvect).sum(axis=1)[:, numpy.newaxis]
        return numpy.fft.rfft(irvect / energy)

    def _generate_tf_gss(self, irvect):
        """Separation TF for a single IRVector, normalized as a one-element stack."""
        tfvect = self._energy_normalized_rfft(irvect)
        return self._normalize_tf(tfvect[numpy.newaxis])[0]

    def _normalize_tf(self, source: numpy.ndarray) -> numpy.ndarray:
        """Divide a 3-D TF array by sums of its magnitudes.

        Axis 0 indexes IRVectors (presumably sources), axis 1 the rows of each
        vector (presumably microphones), and axis 2 the frequency bins. Each enabled
        flag divides by the magnitude sum over the other two axes. All magnitudes
        come from the unnormalized input.
        """
        x0 = source.copy()
        xp = numpy.abs(x0)

        if self._normalize_source:
            x0 = x0 / xp.sum(axis=(1, 2))[:, None, None]
        if self._normalize_microphone:
            x0 = x0 / xp.sum(axis=(0, 2))[None, :, None]
        if self._normalize_frequency:
            x0 = x0 / xp.sum(axis=(0, 1))[None, None, :]

        return x0


class IRFromTSP:
    """Estimate impulse responses from a recording of a repeatedly played TSP signal."""

    # Samples added to the correlation peak position when the sync offset is searched.
    _DISTANCE_OFFSET = -50
    # Number of leading samples kept from each estimated impulse response.
    _HSIZE = 4096
    # An earlier correlation peak must lie more than this many samples (3/4 of 16384)
    # before the current choice to replace it.
    _EARLIER_PEAK_DISTANCE = (16384 / 4) * 3

    def __init__(self, sampling_rate: int, samples_by_frame: int, sync_add_frequency: int, sync_add_offset: int, signal_maximum: float):
        """
        Args:
            sampling_rate: sampling rate the TSP signal must have.
            samples_by_frame: length of one averaged segment and of the TSP FFT.
            sync_add_frequency: number of consecutive segments averaged.
            sync_add_offset: first averaged sample when no search range is given.
            signal_maximum: divisor applied to the TSP; if <= 0, the largest absolute
                sample of the TSP waveform is used instead.
        """
        super().__init__()
        self._sampling_rate = sampling_rate
        self._samples_by_frame = samples_by_frame
        self._sync_add_frequency = sync_add_frequency
        self._sync_add_offset = sync_add_offset
        self._signal_maximum = signal_maximum

    def __call__(self, tsp_original_signal: SoundData, tsp_original_channel: int, input_signal: SoundData, output_channels: list[int], search_offset: tuple[int, int]) -> SoundData:
        """Deconvolve the averaged recording of each output channel by the TSP.

        Args:
            tsp_original_signal: the played TSP; its ``waveform`` is indexed as ``[:, channel]``.
            tsp_original_channel: channel of ``tsp_original_signal`` holding the TSP.
            input_signal: the recording.
            output_channels: channels of ``input_signal`` to process.
            search_offset: ``(from, to)`` sample range in which the TSP start is searched
                by cross-correlation; if both values are negative, the configured
                ``sync_add_offset`` is used instead.

        Returns:
            SoundData with one impulse response of up to ``_HSIZE`` samples per output channel.

        Raises:
            HarkToolError: on sampling-rate mismatch, a recording too short for the
                averaging, or a search range that cannot be widened to the TSP length.
        """
        if tsp_original_signal.sampling_rate != self._sampling_rate:
            raise HarkToolError('Sampling rate mismatch')

        tsp_waveform = tsp_original_signal.waveform
        tsp_channel = tsp_waveform[:, tsp_original_channel]
        if self._signal_maximum <= 0:
            signal_maximum = numpy.abs(tsp_waveform).max()
        else:
            signal_maximum = self._signal_maximum
        tsp_fft = numpy.fft.rfft(tsp_channel / signal_maximum, self._samples_by_frame)

        averaged_length = self._sync_add_frequency * self._samples_by_frame
        if input_signal.frame_count < averaged_length + self._sync_add_offset:
            raise HarkToolError(
                'Wave length({}) is short when syncAddFreq = {}.'.format(input_signal.frame_count, self._sync_add_frequency))

        if max(search_offset) >= 0:
            sync_add_offset = self._search_sync_add_offset(
                input_signal, tsp_channel, tsp_original_signal.frame_count, output_channels, search_offset)
            # The detected offset may be later than the configured one, so re-check the length.
            if input_signal.frame_count < averaged_length + sync_add_offset:
                raise HarkToolError(
                    'Wave length({}) is short when syncAddFreq = {} and detected syncAddOffset = {}.'.format(
                        input_signal.frame_count, self._sync_add_frequency, sync_add_offset))
        else:
            sync_add_offset = self._sync_add_offset

        tsp_power = tsp_fft * tsp_fft.conj()
        output_sounds = SoundData(sampling_rate=self._sampling_rate)
        for channel_index in output_channels:
            # Average `sync_add_frequency` consecutive segments of `samples_by_frame` samples.
            xmean = input_signal[channel_index][sync_add_offset:][:averaged_length] \
                .reshape(self._sync_add_frequency, -1).mean(axis=0)
            hfft = numpy.fft.rfft(xmean) * tsp_fft.conj() / tsp_power
            h_wave_buf = numpy.fft.irfft(hfft)

            output_sounds.append(sound=h_wave_buf[:IRFromTSP._HSIZE].real[numpy.newaxis, :], sampling_rate=self._sampling_rate)

        return output_sounds

    def _search_sync_add_offset(self, input_signal: SoundData, tsp_channel: numpy.ndarray, tsp_length: int,
                                output_channels: list[int], search_offset: tuple[int, int]) -> int:
        """Return the earliest TSP start found over the output channels, clamped at 0.

        Raises:
            HarkToolError: if the search range is shorter than the TSP and cannot be widened.
        """
        search_from, search_to = search_offset
        search_to = min(search_to, input_signal.frame_count)
        # A range shorter than the TSP is widened by one TSP length: backwards if there
        # is room before it, otherwise forwards.
        if search_to < search_from + tsp_length:
            if search_from >= tsp_length:
                search_from -= tsp_length
            elif search_to + tsp_length <= input_signal.frame_count:
                search_to += tsp_length
            else:
                raise HarkToolError(
                    'Search length is under tsp size. But there is not enough area on impulse.')

        calculated_offsets = [
            self._correlation_on_fft(input_signal.waveform[:, channel_index], tsp_channel, search_from, search_to)
            for channel_index in output_channels]
        # A negative offset would make the later slice start count from the end of the signal.
        return max(0, min(calculated_offsets))

    def _correlation_on_fft(self, x: numpy.ndarray, y: numpy.ndarray, start_pos: int, end_pos: int) -> int:
        """Locate *y* inside ``x[start_pos:end_pos]`` by FFT cross-correlation.

        Among the last ten positions where the correlation magnitude equals its
        running maximum, the latest (a global maximum) is taken. It is replaced by
        an earlier candidate lying more than ``_EARLIER_PEAK_DISTANCE`` samples
        before it with more than half its height.

        Returns:
            ``start_pos + peak + _DISTANCE_OFFSET``, or 0 if position 0 is among the candidates.
        """
        search_len: int = end_pos - start_pos
        len_temp = max(len(y), search_len)

        # Smallest power of two >= len_temp, as an int: numpy.fft needs an integer n,
        # and rounding down would crop the longer of the two signals.
        fft_len = 1 << (max(1, int(len_temp)) - 1).bit_length()

        x_buf = x[start_pos:][:search_len]  # the search window itself

        x_fft = numpy.fft.rfft(x_buf, fft_len)
        y_fft = numpy.fft.rfft(y, fft_len)
        cxy = numpy.abs(numpy.fft.irfft(y_fft.conj() * x_fft, fft_len))

        # Positions where cxy equals its running maximum; last ten, latest first.
        peak_corr_idx = numpy.flatnonzero(cxy == numpy.maximum.accumulate(cxy))[-10:][::-1]
        max_idx = peak_corr_idx[0]

        for peak_corr_i in peak_corr_idx:
            if peak_corr_i < max_idx - IRFromTSP._EARLIER_PEAK_DISTANCE and cxy[peak_corr_i] > cxy[max_idx] / 2:
                max_idx = peak_corr_i
            if peak_corr_i == 0:
                # NOTE(review): kept from the original: returns 0 whenever position 0 is
                # among the candidates, regardless of start_pos and the chosen peak; confirm intent.
                return 0

        return int(start_pos + max_idx + IRFromTSP._DISTANCE_OFFSET)


class IRFromMouthTSP:
    """Estimate impulse responses relative to a reference ("mouth") TSP channel.

    For each output channel, a transfer function is computed as the STFT
    cross-spectrum with the reference divided by the reference auto-spectrum, both
    summed over frames. It is returned as an impulse response via an inverse real FFT.
    """

    def __init__(self,
                 frame_start: int,  # s: first reference sample analysed
                 frame_end: int,  # e: end (exclusive) of the reference samples
                 frame_head_margin: int,  # head_margin: output channels are read this many samples earlier
                 add_white_noise: float,  # white noise level (stored only; not read by this class)
                 fft_window: int,  # nfft
                 fft_hop: int,  # hop
                 sampling_rate: int,  # samplingrate; falsy disables the rate check
                 ):
        super().__init__()
        self._frame_start = frame_start
        self._frame_end = frame_end
        self._frame_head_margin = frame_head_margin
        self._add_white_noise = add_white_noise
        self._fft_window = fft_window
        self._fft_hop = fft_hop
        self._sampling_rate = sampling_rate

    def __call__(self, input_sound: SoundData,
                 input_channel: int,
                 output_channels: list[int],  # like [0, 1, 3, 4, ...]
                 workdir: WorkingDirectory = None,  # not used by this method
                 ) -> SoundData:
        """Compute one impulse response per output channel.

        Args:
            input_sound: recording containing the reference and the output channels.
            input_channel: index of the reference channel.
            output_channels: channel ids passed to ``input_sound.channel_by_channel_id``.
            workdir: unused.

        Returns:
            SoundData with one impulse response per output channel, at the input's sampling rate.

        Raises:
            HarkToolError: if the recording is too short, the reference channel is
                missing, the sampling rate differs from the configured one, or the
                head-margin-shifted window falls outside the recording.
        """
        # validate input format/size
        if input_sound.frame_count < self._frame_end:
            raise HarkToolError(
                'Wav({}) frames({}) is less than end of sample ({})'.format(input_sound, input_sound.frame_count, self._frame_end))

        if input_channel >= input_sound.channel_count:
            raise HarkToolError(
                'mouth TSP data not found: wav channel [{}] in file [{}]'.format(input_channel, input_sound))

        if self._sampling_rate and input_sound.sampling_rate != self._sampling_rate:
            raise HarkToolError('Config SamplingRate ({}) differ from Wav({}) SamplingRate ({})'.format(
                self._sampling_rate, input_sound, input_sound.sampling_rate))

        # The output channels are read `frame_head_margin` samples earlier; a negative
        # slice start would silently wrap to the end of the signal.
        output_from = self._frame_start - self._frame_head_margin
        output_to = self._frame_end - self._frame_head_margin
        if output_from < 0 or output_to > input_sound.frame_count:
            raise HarkToolError(
                'Head margin ({}) moves the window [{}, {}) outside Wav({}) frames({})'.format(
                    self._frame_head_margin, output_from, output_to, input_sound, input_sound.frame_count))

        # Use the input's rate so the output is valid even when no rate is configured
        # (when one is configured, the check above makes them equal).
        output_waves = SoundData(sampling_rate=input_sound.sampling_rate)

        stft = scipy.signal.ShortTimeFFT(win=scipy.signal.windows.hamming(
            self._fft_window), hop=self._fft_hop, fs=input_sound.sampling_rate, mfft=self._fft_window)
        # For a 1-D input ShortTimeFFT.stft returns (frequency, time), so axis 1 is time.
        Si = stft.stft(input_sound[input_channel][self._frame_start:self._frame_end])
        Si_conj = Si.conj()
        reference_power = (Si_conj * Si).sum(axis=1)

        for channel_id in output_channels:
            So = stft.stft(input_sound.channel_by_channel_id(channel_id)[output_from:output_to])
            Ro = (Si_conj * So).sum(axis=1) / reference_power
            YY = scipy.fft.irfft(Ro).real / self._fft_window

            output_waves.append(YY[numpy.newaxis, :])

        return output_waves

# end of file
