import os
import time

import numpy
import torch
import torch.nn as nn

from . import vmvivit
from collections import deque


class VIVIT:
    """Run the ``vmvivit.Vit`` network and report the strongest peaks of its output.

    Each call to :meth:`forward` with a tensor runs the network, applies a
    sigmoid, and queues one entry per index along the first dimension of the
    result. Results are then handed out one queued entry at a time, oldest
    first, as a ``(peak_indices, peak_values)`` pair.

    Attributes set by this class:
        num_source: Maximum number of peaks reported per entry (2).
        stack: FIFO queue (``collections.deque``) of pending output entries.
        model_output: Sigmoid-activated output of the most recent tensor call.
        output: The queue entry most recently analysed by :meth:`peak_search`.
        peak / power: Indices and values found by the last :meth:`peak_search`.
    """

    def __init__(self, device):
        """Create the network on ``device`` and load weights from ``vm_model.pth``.

        Args:
            device: Forwarded to ``Module.to`` and used as ``map_location``
                when loading the checkpoint stored next to this file.
        """
        self.activation = nn.Sigmoid()
        # os.path.join (instead of a hand-built "dir/name" string) keeps the
        # path relative when dirname(__file__) is empty, rather than pointing
        # at the filesystem root.
        model_loc = os.path.join(os.path.dirname(__file__), "vm_model.pth")
        self.vmvivit = vmvivit.Vit().to(device)
        self.vmvivit.load_state_dict(torch.load(model_loc, map_location=device))
        self.vmvivit.eval()
        self.num_source = 2
        self.stack = deque()
        # Defined up front so the attributes exist before the first forward().
        self.model_output = None
        self.output = None
        self.peak = []
        self.power = []

    def peak_search(self):
        """Find the ``num_source`` largest local maxima of ``self.output``.

        An element is a peak when it is strictly greater than both of its
        neighbours. Neighbours are obtained with ``torch.roll``, so the
        comparison wraps around: the first element is compared against the
        last one and vice versa.

        Results are stored in ``self.peak`` (indices) and ``self.power``
        (values), both plain lists ordered by descending value. Both are empty
        when no element qualifies.
        """
        # NOTE(review): only the first index tensor of torch.where is used, so
        # this presumably expects self.output to be 1-D -- confirm.
        prev_output = torch.roll(self.output, 1)
        next_output = torch.roll(self.output, -1)
        peak_mask = (self.output > prev_output) & (self.output > next_output)
        peak_indices = torch.where(peak_mask)[0]
        peak_values = self.output[peak_indices]
        strongest = torch.argsort(peak_values, descending=True)[:self.num_source]
        self.peak = peak_indices[strongest].tolist()
        self.power = peak_values[strongest].tolist()

    def _next_peaks(self):
        """Analyse the oldest queued entry; return ``([], [])`` if none is queued."""
        if not self.stack:
            return [], []
        self.output = self.stack.popleft()
        self.peak_search()
        return self.peak, self.power

    def forward(self, x):
        """Feed new input or drain the queue, returning one entry's peaks.

        Args:
            x: Either a tensor, which is run through the network and whose
                activated output is appended to the queue entry by entry, or
                a plain value: ``1`` requests the next queued entry, anything
                else (e.g. ``0``) requests nothing.

        Returns:
            A ``(peak, power)`` pair of lists for the oldest queued entry, or
            ``([], [])`` when nothing is queued or nothing was requested.
        """
        if torch.is_tensor(x):
            with torch.inference_mode():
                self.model_output = self.activation(self.vmvivit(x))
            # One queue entry per index along the first dimension.
            self.stack.extend(self.model_output.unbind(0))
            # Guarded pop: an input with an empty first dimension and an empty
            # queue yields ([], []) instead of raising IndexError.
            return self._next_peaks()

        if x == 1:
            return self._next_peaks()

        # Always hand back a 2-tuple so callers can unpack the result, even
        # for values other than 0 and 1.
        return [], []
       
        

