#!/usr/bin/python
# -*- coding: utf-8 -*-
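"""
harktool/libharkio3/neighbors.py

Neighbor relations between positions, and their conversion to and from
``<neighbors>`` XML elements.
"""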

import xml.etree.ElementTree as ET
from xml.etree.ElementTree import Element

from sklearn.neighbors import NearestNeighbors

import numpy

from .defs import CoordinateSystem, NeighborAlgorithm
from .error import LibHarkIOError
from .positions import Positions

import pydantic

class Neighbors(pydantic.BaseModel):
    """Neighbor relations between positions.

    ``neighbors`` maps a position id to the list of position ids that are
    neighbors of that position under ``algorithm``.
    """

    class Config:
        arbitrary_types_allowed = True

    algorithm: NeighborAlgorithm = NeighborAlgorithm.Undefined
    # Keyed by position id; values are the ids of the neighboring positions.
    # (pydantic copies field defaults per instance, so the literal {} is safe.)
    neighbors: dict[int, list[int]] = {}
    # Positions the neighbors were computed from; None when the object was
    # built without them (e.g. parsed from XML).
    positions: Positions = None

    def find_neighbors(self, position_id: int) -> list[int]:
        """Return the ids of the neighbors of the position ``position_id``.

        Raises:
            LibHarkIOError: if no algorithm is set, or if there is no
                neighbor entry for ``position_id``.
        """
        if self.algorithm == NeighborAlgorithm.Undefined:
            raise LibHarkIOError('algorithm is undefined')

        # ``neighbors`` is keyed by position id, so look the id up directly.
        # This also works when ``positions`` is None (e.g. after parsing XML).
        try:
            return self.neighbors[position_id]
        except KeyError:
            raise LibHarkIOError(
                f'Neighbor not found for position with id {position_id}') from None

    def calculate_neighbors(self, algorithm: NeighborAlgorithm,
                            positions: Positions, **kwargs) -> None:
        """Compute the neighbors of every position in ``positions``.

        For ``NeighborAlgorithm.NearestNeighbor`` the keyword arguments
        ``count`` (number of nearest neighbors, > 0) and ``threshold``
        (maximum distance, > 0) are required.

        On success, ``algorithm``, ``positions`` and ``neighbors`` are
        updated; ``neighbors`` then maps every position id to the ids of
        its neighbors.

        Raises:
            LibHarkIOError: if ``algorithm`` is unknown or its arguments
                are invalid.
        """
        if algorithm != NeighborAlgorithm.NearestNeighbor:
            raise LibHarkIOError('algorithm is unknown.')

        count = kwargs.get('count')
        threshold = kwargs.get('threshold')
        rows = self.algorithms_nearest_neighbor(positions, count, threshold)

        # The search yields row indices into ``positions.positions``.
        # Translate them to position ids so keys and values share one id space.
        ids = [p.position_id for p in positions.positions]

        self.algorithm = algorithm
        self.positions = positions
        self.neighbors = {
            pid: [ids[j] for j in row] for pid, row in zip(ids, rows)
        }

    def algorithms_nearest_neighbor(self, positions: Positions,
                                    count: int, threshold: float) -> list[numpy.ndarray]:
        """Find up to ``count`` nearest positions within ``threshold``.

        Distances are Euclidean, computed on the Cartesian coordinates of
        each position. Each point is a member of the fitted set, so it
        appears in its own result (distance 0). If ``count`` exceeds the
        number of positions, every position is considered.

        Returns:
            One integer array per position (in ``positions.positions``
            order) holding the row indices of its neighbors, nearest first.

        Raises:
            LibHarkIOError: if ``count`` or ``threshold`` is missing or
                not positive.
        """
        if count is None or count <= 0:
            raise LibHarkIOError("n-nearest neighbors must be (n > 0)")

        if threshold is None or threshold <= 0:
            raise LibHarkIOError("threshold must be positive")

        xs = numpy.array([p.as_coordinate(CoordinateSystem.Cartesian)
                          for p in positions.positions])
        if len(xs) == 0:
            return []

        # sklearn rejects n_neighbors > n_samples, so clamp to what exists.
        # The built-in 'euclidean' metric matches the former lambda
        # (norm of x - y) but runs in native code.
        nn = NearestNeighbors(n_neighbors=min(count, len(xs)), metric='euclidean')
        nn.fit(xs)

        distances, neighbors = nn.kneighbors(xs)
        return [row[dist <= threshold] for row, dist in zip(neighbors, distances)]


class NeighborsParser:
    """Convert :class:`Neighbors` to and from ``<neighbors>`` XML elements.

    Layout::

        <neighbors algorithm="...">
            <neighbor id="<position id>" ids="<id>;<id>;..."/>
            ...
        </neighbors>
    """

    @classmethod
    def as_element(cls, neighbors: Neighbors) -> Element:
        """Serialize ``neighbors`` into a ``<neighbors>`` element."""
        element = ET.Element('neighbors', attrib={
            'algorithm': str(neighbors.algorithm.value)
        })

        for position_id, neighbor_ids in neighbors.neighbors.items():
            element.append(ET.Element('neighbor', {
                'id': str(position_id),
                'ids': ';'.join(str(i) for i in neighbor_ids),
            }))

        return element

    @classmethod
    def from_element(cls, element: Element) -> Neighbors:
        """Build a :class:`Neighbors` from a ``<neighbors>`` element.

        Child elements other than ``<neighbor>`` are ignored. A missing
        ``algorithm`` attribute yields ``NeighborAlgorithm.Undefined``, and a
        missing or empty ``ids`` attribute yields an empty neighbor list.

        Raises:
            LibHarkIOError: if ``element`` is not ``<neighbors>`` or a
                ``<neighbor>`` child lacks its ``id`` attribute.
            ValueError: if the ``algorithm`` attribute names no known algorithm.
        """
        if element.tag != 'neighbors':
            raise LibHarkIOError(
                f'expected <neighbors> element, got <{element.tag}>')

        algorithm = cls._parse_algorithm(element.attrib.get('algorithm'))

        neighbors = {}
        for child in element:
            if child.tag != 'neighbor':
                continue

            raw_id = child.attrib.get('id')
            if raw_id is None:
                raise LibHarkIOError(
                    "<neighbor> element is missing the 'id' attribute")

            raw_ids = child.attrib.get('ids', '')
            neighbors[int(raw_id)] = [int(s) for s in raw_ids.split(';') if s]

        return Neighbors(algorithm=algorithm, neighbors=neighbors)

    @staticmethod
    def _parse_algorithm(value):
        """Map the ``algorithm`` attribute text back to a NeighborAlgorithm."""
        if value is None:
            return NeighborAlgorithm.Undefined
        try:
            return NeighborAlgorithm(value)
        except ValueError:
            # as_element writes str(algorithm.value). If the enum's values are
            # integers, the text must be converted back before the lookup.
            if not value.strip().lstrip('-').isdigit():
                raise
            return NeighborAlgorithm(int(value))

# end of file
