import time

import numpy
import torch

import hark
import hark.base
import hark.node

from . import VIVIT, Stack_Batch, Stack_Frame, vmvivit


class User_node_VIVIT(hark.node.PythonNode):
    """HARK Python node that runs a ViViT-based source-direction estimator.

    Inputs:
        INPUT: passed through ``numpy.angle`` before frame/batch stacking
            (presumably a complex multichannel spectrum -- TODO confirm
            against the upstream node).
        BATCH_SIZE: optional; when present it must be a positive integer
            multiple of 4 and replaces ``self.batch_len`` handed to
            ``Stack_Batch.forward``.

    Output:
        OUTPUT: list of ``hark.base.harklib.Source`` objects, one per peak
            returned by the estimator (at most ``self.c.num_source``).
            Nothing is emitted when the estimator returns no peaks.
    """

    # Degrees represented by one unit of a peak index from VIVIT.forward.
    # NOTE(review): taken from the original ``j*5`` conversion -- confirm it
    # matches the model's angular class layout.
    DEGREES_PER_BIN = 5

    def __init__(self):
        super().__init__(["INPUT", "BATCH_SIZE"], "OUTPUT")
        self.count = 0  # incremented once per forward() that finishes inference
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_len = 64  # default until a BATCH_SIZE input is received
        print(self.device)
        self.a = Stack_Frame.Stack_Frame()
        self.b = Stack_Batch.Stack_Batch()
        self.c = VIVIT.VIVIT(self.device)

    @staticmethod
    def _checked_batch_size(value):
        """Return *value* as an int if it is a positive multiple of 4.

        Raises:
            ValueError: if *value* is not a positive integer multiple of 4.
        """
        # bool is a subclass of int, so it is rejected explicitly.
        if (
            isinstance(value, (int, numpy.integer))
            and not isinstance(value, bool)
            and value > 0
            and value % 4 == 0
        ):
            return int(value)
        raise ValueError(
            f"Batch size must be a natural multiple of 4, got {value!r}"
        )

    def forward(self, **kwargs):
        """Estimate source directions for one input frame.

        Args:
            **kwargs: node inputs; ``INPUT`` is required and ``BATCH_SIZE``
                is optional.

        Returns:
            ``{"OUTPUT": [Source, ...]}`` when the estimator reports at least
            one peak, otherwise ``{}``.

        Raises:
            ValueError: if ``BATCH_SIZE`` is supplied but is not a positive
                integer multiple of 4.
        """
        if "BATCH_SIZE" in kwargs:
            self.batch_len = self._checked_batch_size(kwargs["BATCH_SIZE"])

        self.x = numpy.angle(kwargs["INPUT"])
        self.x = self.a.forward(self.x, self.device)
        self.x = self.b.forward(self.x, self.batch_len)
        peak, power = self.c.forward(self.x)

        self.count += 1
        if len(peak) == 0:
            return {}

        # Never index past what the estimator actually returned; the old code
        # assumed len(peak) and len(power) were both >= num_source and raised
        # IndexError otherwise.
        n_sources = min(self.c.num_source, len(peak), len(power))
        srcs = []
        for i in range(n_sources):
            deg = peak[i] * self.DEGREES_PER_BIN
            rad = numpy.radians(deg)
            print(f"deg={deg}")
            s = hark.base.harklib.Source()
            # NOTE(review): every source shares id 0 -- confirm downstream
            # nodes do not rely on unique ids.
            s.id = 0
            # Unit direction vector in the horizontal plane.
            s.x = [numpy.cos(rad), numpy.sin(rad)]
            s.power = power[i]
            srcs.append(s)
        return {"OUTPUT": srcs}
