#!/usr/bin/python
# -*- coding: utf-8 -*-
"""

harktool/app/workset.py

"""

import typing
from datetime import datetime
from pathlib import Path

import soundfile

import numpy

from .error import HarkToolError


class WorkingDirectory:
    """Path layout of a timestamped scratch directory for harktool.

    The working directory is ``<base>/<YYYYmmddHHMMSS>``. ``<base>`` defaults
    to ``DEFAULT_BASE_PATH`` and the timestamp defaults to the time of
    construction. Nothing is created on disk by this class; it only computes
    paths.
    """
    DEFAULT_BASE_PATH = '.harktool-temp'

    def __init__(self, path: str = None, timestamp: datetime = None, remove_on_exit: bool = False):
        """
        Args:
            path: base directory; ``DEFAULT_BASE_PATH`` when None.
            timestamp: time used to name the sub-directory; ``datetime.now()`` when None.
            remove_on_exit: stored only.
        """
        super().__init__()
        if path is None:
            path = WorkingDirectory.DEFAULT_BASE_PATH
        if timestamp is None:
            timestamp = datetime.now()

        self._base_path = Path(path)
        self._timestamp = timestamp
        # TODO: remove_on_exit is recorded, but nothing in this class acts on it yet.
        self._remove_on_exit = remove_on_exit

        self._timestamp_str = timestamp.strftime('%Y%m%d%H%M%S')
        self._working_path = self._base_path.joinpath(self._timestamp_str)

    def tsp_path(self, source: str, index: int, mode: str) -> Path:
        """Return ``<working>/tsp/<source>/float_chNNNNN.flt`` for channel ``index``.

        ``mode`` is accepted for interface compatibility but is not used.
        """
        return self._working_path.joinpath('tsp', source, f'float_ch{index:05d}.flt')


class SoundData:
    """
    複数チャネルの音声データを管理する
    原則オンメモリで保持するが、workdir が与えられている場合、適宜な一時ファイルを作成して保存する

    チャネルを識別する方法は２つある
    - チャネル通番を識別する index 。index は 0-origin の通番
    - ID番号で識別する channel_id 。channel_id は unique な 0 以上の整数値を持つ

    """
    def __init__(self, sound: numpy.ndarray=None, channel_ids:list[int]=None, sampling_rate: int=None):
        """_summary_

        Args:
            sound (numpy.ndarray, optional): _description_. Defaults to None.
            channels_indices (list[int], optional): _description_. Defaults to None.
            sampling_rate (int, optional): _description_. Defaults to None.
        """        
        super().__init__()

        channel_count = 0 if sound is None else sound.shape[0]
        if channel_ids is None:
            channel_ids = list(range(channel_count))

        self._sound = sound
        self._channel_ids = channel_ids
        self._sampling_rate = sampling_rate

        # check channel_ids
        if channel_ids is not None:
            if len(channel_ids) != len(set(channel_ids)):
                raise HarkToolError('Duplicate channel ids')
            if len(channel_ids) != (0 if sound is None else sound.shape[0]):
                raise HarkToolError('Channel count mismatch between sound and channel ids')

    @property
    def waveform(self) -> numpy.ndarray:
        """_summary_

        Returns:
            numpy.ndarray: _description_
        """        
        return self._sound

    @property
    def format(self) -> tuple[int, int]:
        return self._sound.shape[0], self._sampling_rate

    @property
    def sampling_rate(self) -> int:
        return self._sampling_rate

    @property
    def channel_count(self) -> int:
        return self._sound.shape[0]

    @property
    def frame_count(self) -> int:
        return self._sound.shape[1]
    
    def channel_by_index(self, index:int) -> numpy.ndarray:
        return self._sound[index, :]
    
    def channel_by_channel_id(self, channel_id:int) -> numpy.ndarray:
        return self._sound[self._channel_ids.index(channel_id), :]

    def append(self, sound: numpy.ndarray, channel_ids:list[int]=None, sampling_rate: int = None):
        """SoundData にチャネルを追加する

        Args:
            wav (numpy.ndarray): 追加するチャネルの音声データ (sound[channel][frame])
            sampling_rate (int, optional): sound のサンプリング周波数. Defaults to None.
        """
        if sampling_rate and self._sampling_rate != sampling_rate:
            raise HarkToolError()

        channel_count = 0 if sound is None else sound.shape[0]
        if channel_ids is None:
            origin = 0 if len(self._channel_ids) == 0 else max(self._channel_ids) + 1
            channel_ids = list(range(channel_count, origin))

        if self._sound is not None:
            self._sound = numpy.r_[self._sound, sound]
        else:
            self._sound = sound
        
        self._channel_ids.extend(channel_ids)

        # check channel_ids
        if len(self._channel_ids) != len(set(self._channel_ids)):
            raise HarkToolError('Duplicate channel ids')

    def __len__(self) -> int:
        return len(self._sound)
    
    def __getitem__(self, index:int) -> numpy.ndarray:
        return self.channel_by_index(index)

    @classmethod
    def from_wav(cls, file: str | typing.IO) -> 'SoundData':
        if isinstance(file, str):
            with open(file, 'rb') as fp:
                return SoundData.from_wav(fp)

        # wav[frame_index][channel_index]
        wav, sr = soundfile.read(file, always_2d=True)
        return SoundData(sampling_rate=sr, sound=wav.T)

    @classmethod
    def from_raw(cls, file: str | typing.IO, sampling_rate:int) -> 'SoundData':
        if isinstance(file, str):
            with open(file, 'rb') as fp:
                return SoundData.from_raw(fp, sampling_rate)

        wav = numpy.fromfile(fp, dtype=numpy.float32)
        return SoundData(sampling_rate=sampling_rate, sound=wav[numpy.newaxis, :])
    
    @classmethod
    def as_raw(cls, sound: 'SoundData', file: str | typing.IO) -> None:
        if isinstance(file, str):
            with open(file, 'rb') as fp:
                SoundData.as_raw(sound, fp)
        
        file.write(sound.waveform.astype(numpy.float32).tobytes())
    
    @classmethod
    def as_wav(cls, sound: 'SoundData', file: str | typing.IO) -> None:
        soundfile.write(file, sound.waveform, sound.sampling_rate)

    @classmethod
    def as_wav(cls, sound: 'SoundData', file: str | typing.IO) -> None:
        soundfile.write(fiie=file, data=sound.waveform)

# end of file
