#!/usr/bin/python
# -*- coding: utf-8 -*-
"""

harktool/libharkio3/positions.py

"""

import typing
import xml.etree.ElementTree as ET
from enum import Enum
from xml.etree.ElementTree import Element

from .defs import CoordinateSystem, PositionsType, UseTagType
from .error import LibHarkIOError
from .position import Position, PositionParser

import pydantic


class Positions(pydantic.BaseModel):
    """An ordered collection of :class:`Position` entries plus header metadata.

    Supports ``positions += position`` to append an entry, ``len(positions)``
    and ``positions[i]`` indexing over the contained entries.

    NOTE: iterating a ``Positions`` instance directly uses
    ``pydantic.BaseModel.__iter__`` (yielding ``(field_name, value)`` pairs),
    not the entries — iterate ``.positions`` to walk the entries.
    """

    # Kept from the original configuration; presumably one of the field types
    # is not natively understood by pydantic — TODO confirm it is still needed.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    # Kind of positions stored (written as the ``type`` XML attribute by the parser).
    type: PositionsType
    # Coordinate system of the entries (``coordinate`` XML attribute).
    coordinate_system: CoordinateSystem = CoordinateSystem.Cartesian
    # Frame value carried with the collection; defaults to 0.
    frame: int = 0
    # The entries; a fresh list is created for every instance.
    positions: list[Position] = pydantic.Field(default_factory=list)

    @pydantic.model_serializer
    def serialize_dt(self) -> dict[str, typing.Any]:
        """Serialize to a plain dict.

        Enum fields are rendered with ``str()``; each entry is dumped with its
        own ``model_dump()``. The return annotation matters: pydantic uses it
        to serialize the returned value, so it must admit the int ``frame``
        and the list of entry dicts.
        """
        return {
            'type': str(self.type),
            'coordinate_system': str(self.coordinate_system),
            'frame': self.frame,
            'positions': [p.model_dump() for p in self.positions],
        }

    def __iadd__(self, position: Position) -> 'Positions':
        """Append *position* in place and return ``self`` (supports ``+=``)."""
        self.positions.append(position)
        return self

    def __len__(self) -> int:
        """Return the number of contained entries."""
        return len(self.positions)

    def __getitem__(self, index: int) -> Position:
        """Return the entry at *index* (list indexing semantics)."""
        return self.positions[index]

    def query_position_by_id(self, id: int) -> Position | None:
        """Return the first entry whose ``position_id`` equals *id*, or ``None``."""
        return next((p for p in self.positions if p.position_id == id), None)

    def sort_positions_by_id(self) -> 'Positions':
        """Sort the entries by ``position_id`` and return ``self``.

        Previously this returned ``None`` despite its annotation; returning
        ``self`` honours the declared type and allows chaining.
        """
        self.positions = sorted(self.positions, key=lambda p: p.position_id)
        return self


class PositionsParser:
    """Convert :class:`Positions` to and from ``xml.etree.ElementTree`` elements."""

    @staticmethod
    def _format_channels(channels) -> str:
        """Render a channel list as ``'c1;c2;...;'`` (every item followed by ``;``)."""
        return ''.join(f'{c};' for c in channels)

    @classmethod
    def as_element(cls, positions: Positions) -> Element:
        """Build a ``<positions>`` element from *positions*.

        The ``type`` and ``coordinate`` attributes are the string form of the
        enum values. Each entry becomes a child via
        ``PositionParser.as_element``. ``positions.frame`` is not written.
        """
        element = ET.Element('positions', dict(
            type=str(positions.type.value),
            coordinate=str(positions.coordinate_system.value),
        ))
        element.extend(PositionParser.as_element(p) for p in positions.positions)
        return element

    @classmethod
    def from_element(cls, element: Element) -> Positions:
        """Parse a ``<positions>`` element into a :class:`Positions`.

        Raises:
            LibHarkIOError: if *element* is not ``<positions>`` or contains a
                child other than ``<position>``.
        """
        if element.tag != 'positions':
            raise LibHarkIOError(f'Unexpected element <{element.tag}>.')

        positions_type = PositionsType(element.attrib.get('type'))
        coordinate_system = CoordinateSystem(element.attrib.get('coordinate'))

        # The keyword must match the pydantic field name exactly: pydantic
        # ignores unknown keyword arguments by default, so a misspelling here
        # silently falls back to the Cartesian default.
        new_positions = Positions(
            type=positions_type, coordinate_system=coordinate_system)
        for child in element:
            if child.tag != 'position':
                raise LibHarkIOError(f'Unexpected element <{child.tag}>.')
            # NOTE(review): keyword ``coodinate_system`` is kept verbatim because
            # it must match PositionParser.from_element's signature (defined in
            # .position, not visible here) — verify before renaming.
            new_positions += PositionParser.from_element(
                child, coodinate_system=coordinate_system)

        return new_positions

    @classmethod
    def append_channels_element(cls, positions: Positions, element: Element) -> Element:
        """Append channel-usage information for *positions* to *element*.

        If any entry has ``channels_use == UseTagType.USE_TAG``, one
        ``<channels use="...">`` element, built from the first such entry, is
        appended to *element*, and the ``<channel>`` children go inside it.
        Otherwise the ``<channel>`` children are appended directly to *element*.
        One ``<channel id="..." use="...">`` is written per entry with
        ``channels_use == UseTagType.ID_TAG``.

        Returns:
            *element* (modified in place).
        """
        channels_element = element
        use_tag_position = next(
            (p for p in positions.positions if p.channels_use == UseTagType.USE_TAG),
            None)
        if use_tag_position is not None:
            channels_element = ET.SubElement(
                element, 'channels',
                {'use': cls._format_channels(use_tag_position.channels)})

        for position in positions.positions:
            if position.channels_use == UseTagType.ID_TAG:
                # ElementTree can only serialize string attribute values, so
                # the id is converted explicitly.
                ET.SubElement(channels_element, 'channel', {
                    'id': str(position.position_id),
                    'use': cls._format_channels(position.channels),
                })

        return element

    @classmethod
    def from_channels_element(cls, position: Positions, element: Element) -> Positions:
        """Not implemented yet; currently returns ``None``.

        TODO: parse ``<channels>``/``<channel>`` elements back into *position*.
        """
        return None

    @classmethod
    def from_file(cls, file: str | typing.IO) -> Positions:
        """Read the single ``<positions>`` child of the XML root in *file*.

        Args:
            file: a filesystem path, or an already-open file object.

        Raises:
            LibHarkIOError: if the root has no ``<positions>`` child, or more
                than one.
        """
        if isinstance(file, str):
            # Binary mode lets the XML parser honour the document's encoding
            # declaration instead of the platform's default text encoding.
            with open(file, 'rb') as fp:
                return cls.from_file(fp)

        root = ET.parse(file).getroot()
        positions_elements = root.findall('positions')
        if not positions_elements:
            raise LibHarkIOError('No <positions> in XML')
        if len(positions_elements) > 1:
            raise LibHarkIOError('Duplicate <positions> in XML')
        return cls.from_element(positions_elements[0])

# end of file
