#!/usr/bin/python
# -*- coding: utf-8 -*-
"""

harktool/libharkio3/transferfunction.py

"""

import struct
import xml.etree.ElementTree as ET
import zipfile
from typing import Any

import numpy
import pydantic

from ..libharkio3.defs import DataType
from .config import Config as _Config, ConfigParser
from .error import LibHarkIOError
from .neighbors import Neighbors, NeighborsParser
from .positions import Positions, PositionsParser
from .xml import HarkXML, HarkXMLParser

_HARKIO_TAG = "HARK1.3"
_HARKIO_MATRIX_DIM = 2

_DATATYPES_TO_DTYPE = {
    DataType.int32:   numpy.dtype('int32'),
    DataType.float32: numpy.dtype('float32'),
    DataType.complex: numpy.dtype('complex64'),
}

_DTYPE_TO_DATATYPES = dict((v, k) for k, v in _DATATYPES_TO_DTYPE.items())


class TransferFunction(pydantic.BaseModel):
    """In-memory transfer-function set: geometry, configuration and matrices.

    ``loc_tfs`` / ``sep_tfs`` hold the localization and separation matrices.
    ``TransferFunctionParser`` pairs them, in order, with the entries of
    ``positions.positions`` when writing an archive.
    """

    # numpy arrays are not pydantic-native types.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    positions: Positions | None = None
    microphones: Positions | None = None
    config: _Config | None = None
    neighbors: Neighbors | None = None
    loc_tfs: list[numpy.ndarray] = pydantic.Field(default_factory=list)
    sep_tfs: list[numpy.ndarray] = pydantic.Field(default_factory=list)

    @pydantic.model_serializer
    def serialize_dt(self) -> dict[str, Any]:
        """Summarize the model for dumping.

        Sub-models are dumped with their own ``model_dump()``. Unset (None)
        sub-models are emitted as None instead of raising AttributeError.
        The TF matrix lists are reduced to their length, as a decimal string.
        """
        def dump(model: pydantic.BaseModel | None) -> dict[str, Any] | None:
            return None if model is None else model.model_dump()

        return {
            'positions': dump(self.positions),
            'microphones': dump(self.microphones),
            'neighbors': dump(self.neighbors),
            'config': dump(self.config),
            'loc_tfs': f'{len(self.loc_tfs)}',
            'sep_tfs': f'{len(self.sep_tfs)}',
        }

class TransferFunctionParser:
    """Read and write ``TransferFunction`` objects as HARK zip archives.

    Archive members (all under ``transferFunction/``)::

        whatisthis.txt              signature text ('transfer function')
        microphones.xml             microphone positions
        source.xml                  source positions, neighbors, config
        localization/tfNNNNN.mat    one matrix per source position_id
        separation/tfNNNNN.mat      one matrix per source position_id
    """

    _SIGNATURE = 'transfer function'
    _ROOT = 'transferFunction'
    # Width of the space-padded tag / data-type fields in a matrix file.
    _FIELD_WIDTH = 32

    # ------------------------------------------------------------------ zip

    @classmethod
    def _member(cls, *parts: str) -> str:
        """Return the archive member name for *parts* below the root folder."""
        return '/'.join((cls._ROOT, *parts))

    @staticmethod
    def _tf_filename(position) -> str:
        """Return the matrix file name used for a source *position*."""
        return f'tf{position.position_id:05d}.mat'

    @classmethod
    def as_zipfile(cls, tf: TransferFunction, path: str) -> None:
        """Write *tf* to a new zip archive at *path*, replacing any existing file.

        The i-th entry of ``tf.loc_tfs`` / ``tf.sep_tfs`` is stored under the
        ``position_id`` of the i-th entry of ``tf.positions.positions``.
        Pairing uses ``zip``, so any surplus matrices or positions are
        ignored.

        Raises:
            LibHarkIOError: if matrices are present but ``tf.positions`` is
                None, or if a matrix cannot be encoded (see ``write_matrix``).
        """
        # Validate before opening the file so an existing archive is not
        # truncated for nothing.
        if tf.positions is None and (len(tf.loc_tfs) or len(tf.sep_tfs)):
            raise LibHarkIOError(
                'Cannot write transfer function matrices without source positions.')
        positions = tf.positions.positions if tf.positions is not None else ()

        with zipfile.ZipFile(path, 'w') as zipfp:
            with zipfp.open(cls._member('whatisthis.txt'), 'w') as fp:
                fp.write(cls._SIGNATURE.encode())

            xml = HarkXML(positions=tf.microphones)
            with zipfp.open(cls._member('microphones.xml'), 'w') as fp:
                HarkXMLParser.as_file(xml, fp)

            xml = HarkXML(positions=tf.positions,
                          neighbors=tf.neighbors, config=tf.config)
            with zipfp.open(cls._member('source.xml'), 'w') as fp:
                HarkXMLParser.as_file(xml, fp)

            for kind, matrices in (('localization', tf.loc_tfs),
                                   ('separation', tf.sep_tfs)):
                for position, matrix in zip(positions, matrices):
                    name = cls._member(kind, cls._tf_filename(position))
                    with zipfp.open(name, 'w') as fp:
                        cls.write_matrix(fp, matrix)

    @classmethod
    def from_zipfile(cls, path: str) -> TransferFunction:
        """Load a ``TransferFunction`` from the zip archive at *path*.

        Elements missing from ``source.xml`` / ``microphones.xml`` are left
        as None. Matrices are read for each source position whose
        ``tfNNNNN.mat`` member exists, in ``positions.positions`` order.
        Positions without a member are skipped.

        Raises:
            LibHarkIOError: if the file is not a zip (with a conversion hint
                for legacy binary 'HARK' files), the signature does not
                match, a matrix is malformed, or any other error occurs
                while reading (chained as the cause).
        """
        try:
            with zipfile.ZipFile(path, 'r') as zipfp:
                cls._check_signature(zipfp)
                positions, config, neighbors = cls._read_source_xml(zipfp)
                microphones = cls._read_microphones_xml(zipfp)

                members = set(zipfp.namelist())
                position_list = positions.positions if positions is not None else ()
                loc_tfs = cls._read_matrices(zipfp, members, 'localization', position_list)
                sep_tfs = cls._read_matrices(zipfp, members, 'separation', position_list)

        except zipfile.BadZipFile as ex:
            with open(path, 'rb') as fp:
                header = fp.read(4)
            if header == b'HARK':
                raise LibHarkIOError("If you want to convert your binary-format TFs into a zip-format (HARK>=2.1), see 'harktoolcli-conv-tf --help' (after installing harktool5).") from ex
            raise LibHarkIOError(f"Zip open failed: file '{path}' is not in zip format") from ex

        except LibHarkIOError:
            # Already a descriptive domain error; do not bury it.
            raise

        except Exception as ex:
            # Exception (not BaseException) so Ctrl-C / SystemExit propagate.
            raise LibHarkIOError(f"Failed to open file {path}") from ex

        return TransferFunction(positions=positions, microphones=microphones, config=config, neighbors=neighbors,
                                loc_tfs=loc_tfs, sep_tfs=sep_tfs)

    @classmethod
    def _check_signature(cls, zipfp: zipfile.ZipFile) -> None:
        """Raise LibHarkIOError unless whatisthis.txt holds the signature."""
        with zipfp.open(cls._member('whatisthis.txt'), 'r') as fp:
            # Tolerate trailing whitespace/newline added by other tools.
            buffer = fp.read().decode().strip()
        if buffer != cls._SIGNATURE:
            raise LibHarkIOError('Signature mismatch.')

    @classmethod
    def _read_source_xml(cls, zipfp: zipfile.ZipFile):
        """Parse source.xml into ``(positions, config, neighbors)``; absent ones are None."""
        positions = config = neighbors = None
        with zipfp.open(cls._member('source.xml'), 'r') as fp:
            hark_xml = ET.parse(fp).getroot()

        for child in hark_xml:
            if child.tag == 'positions':
                positions = PositionsParser.from_element(child)
            elif child.tag == 'config':
                config = ConfigParser.from_element(child)
            elif child.tag == 'neighbors':
                neighbors = NeighborsParser.from_element(child)

        if neighbors is not None:
            # Neighbor lists refer to the source positions parsed above.
            neighbors.positions = positions
        return positions, config, neighbors

    @classmethod
    def _read_microphones_xml(cls, zipfp: zipfile.ZipFile):
        """Parse microphones.xml and return its positions element (or None)."""
        microphones = None
        with zipfp.open(cls._member('microphones.xml'), 'r') as fp:
            hark_xml = ET.parse(fp).getroot()
        for child in hark_xml:
            if child.tag == 'positions':
                microphones = PositionsParser.from_element(child)
        return microphones

    @classmethod
    def _read_matrices(cls, zipfp: zipfile.ZipFile, members: set[str],
                       kind: str, positions) -> list[numpy.ndarray]:
        """Read ``<kind>/tfNNNNN.mat`` for every position that has one, in order."""
        matrices = []
        for position in positions:
            name = cls._member(kind, cls._tf_filename(position))
            if name in members:
                with zipfp.open(name, 'r') as fp:
                    matrices.append(cls.read_matrix(fp))
        return matrices

    # --------------------------------------------------------------- matrix

    @staticmethod
    def _read_exact(fp, size: int) -> bytes:
        """Read exactly *size* bytes from *fp* or raise LibHarkIOError."""
        data = fp.read(size)
        if len(data) != size:
            raise LibHarkIOError(
                f'Unexpected end of matrix data (wanted {size} bytes, got {len(data)})')
        return data

    @classmethod
    def _pad(cls, text: str) -> bytes:
        """Encode *text* and right-pad it with spaces to the fixed field width."""
        raw = text.encode()
        if len(raw) > cls._FIELD_WIDTH:
            raise LibHarkIOError(f'Field too long for matrix header: {text!r}')
        return raw.ljust(cls._FIELD_WIDTH, b' ')

    @classmethod
    def read_matrix(cls, fp) -> numpy.ndarray:
        """Read one HARK matrix from the binary stream *fp*.

        Layout: 32-byte padded tag (must equal ``_HARKIO_TAG``), 32-byte
        padded data-type string (a ``DataType`` value), then little-endian
        uint32 ``dim`` (must be 2), ``rows`` and ``cols``. These are followed
        by exactly ``rows * cols`` little-endian elements up to end of stream.

        Returns:
            A writable ``(rows, cols)`` array in native byte order.

        Raises:
            LibHarkIOError: on a bad tag, data type or dimension, a truncated
                header, or a payload whose size does not match the shape.
        """
        tag = cls._read_exact(fp, cls._FIELD_WIDTH).decode(errors='replace').strip()
        if tag != _HARKIO_TAG:
            raise LibHarkIOError(f'Unexpected tag {tag}')

        data_type = cls._read_exact(fp, cls._FIELD_WIDTH).decode(errors='replace').strip()
        try:
            datatype = DataType(data_type)
        except ValueError as ex:
            raise LibHarkIOError(f'Unexpected data_type {data_type}') from ex
        dtype = _DATATYPES_TO_DTYPE.get(datatype)
        if dtype is None:
            raise LibHarkIOError(f'Unsupported data_type {data_type}')

        (dim,) = struct.unpack('<I', cls._read_exact(fp, 4))
        if dim != _HARKIO_MATRIX_DIM:
            raise LibHarkIOError(f'Unexpected dim {dim}')

        rows, cols = struct.unpack('<II', cls._read_exact(fp, 8))
        payload = fp.read()
        expected = rows * cols * dtype.itemsize
        if len(payload) != expected:
            raise LibHarkIOError(
                f'Unexpected payload size {len(payload)} '
                f'(expected {expected} for a {rows}x{cols} {dtype} matrix)')

        # Decode as little-endian, then copy into a native-order array:
        # frombuffer over immutable bytes would return a read-only view.
        return (numpy.frombuffer(payload, dtype=dtype.newbyteorder('<'))
                .astype(dtype)
                .reshape(rows, cols))

    @classmethod
    def write_matrix(cls, fp, matrix: numpy.ndarray) -> None:
        """Write *matrix* to the binary stream *fp* in the layout read_matrix expects.

        The header integers and the elements are written little-endian,
        matching what ``read_matrix`` decodes on any host.

        Raises:
            LibHarkIOError: if *matrix* is not 2-D or its dtype has no
                ``DataType`` mapping. Nothing is written in that case.
        """
        # Validate first so an invalid matrix does not leave a partial header.
        if matrix.ndim != _HARKIO_MATRIX_DIM:
            raise LibHarkIOError(f'Unsupported ndim: {matrix.ndim}')
        datatype = _DTYPE_TO_DATATYPES.get(matrix.dtype, None)
        if datatype is None:
            raise LibHarkIOError(f'Unsupported dtype: {matrix.dtype}')

        fp.write(cls._pad(_HARKIO_TAG))
        # read_matrix parses this field with DataType(value), so write the value.
        fp.write(cls._pad(str(datatype.value)))
        fp.write(struct.pack('<III', _HARKIO_MATRIX_DIM, *matrix.shape))
        fp.write(matrix.astype(matrix.dtype.newbyteorder('<'), copy=False).tobytes())
# end of file
