import time

import numpy
import torch


class Stack_Frame:
	"""Buffer per-frame channel data and emit strided stacks of frames.

	Each call to ``forward`` consumes one 2-D frame ``x`` of shape
	``(ch, n)`` with ``n >= f_bin``. The first ``f_bin`` columns are kept,
	row 0 is subtracted from every other row, and the remaining ``ch - 1``
	rows are buffered as a tensor moved to ``device``.

	Once ``stack_len = frame_len * interbal`` frames are buffered, ``forward``
	returns a tensor of shape ``(interbal, frame_len, ch - 1, f_bin)`` whose
	``[j, i]`` entry is buffered frame ``i * interbal + j`` (oldest first),
	i.e. ``interbal`` interleaved sequences that each take every
	``interbal``-th frame. The oldest ``interbal`` frames are then dropped, so
	after warm-up a new tensor is produced every ``interbal`` calls.

	Return values of ``forward``:
	  * ``0`` -- still warming up (fewer than ``stack_len`` frames seen);
	  * ``1`` -- warmed up, but not enough new frames since the last output;
	  * a ``torch.Tensor`` -- a new stacked window.
	"""

	def __init__(self, ch=8, f_bin=256, frame_len=26, interval=4):
		"""Create an empty frame buffer.

		Args:
			ch: Number of rows (channels) expected in each input frame. Row 0
				is used as the reference and dropped, so ``ch - 1`` are kept.
			f_bin: Number of leading columns kept from each frame.
			frame_len: Number of frames in each emitted sequence.
			interval: Stride between consecutive frames of one sequence, and
				the number of new frames between outputs. Stored under the
				historical attribute name ``self.interbal`` for compatibility.

		Raises:
			ValueError: If any of the sizes is not positive.
		"""
		if ch < 1 or f_bin < 1 or frame_len < 1 or interval < 1:
			raise ValueError(
				"ch, f_bin, frame_len and interval must all be positive, got "
				f"ch={ch}, f_bin={f_bin}, frame_len={frame_len}, interval={interval}"
			)
		self.ch = ch
		self.f_bin = f_bin
		self.frame_len = frame_len
		self.interbal = interval  # attribute name kept as-is for existing callers
		# Number of buffered frames needed to build one output window.
		self.stack_len = self.frame_len * self.interbal
		# Buffered frames, oldest first, each shaped (1, ch - 1, f_bin).
		self.stack = []
		# Total number of frames consumed so far (distinguishes warm-up from steady state).
		self.count = 0

	def forward(self, x, device):
		"""Consume one frame and return a status code or a stacked window.

		Args:
			x: 2-D numpy array of shape ``(ch, n)`` with ``n >= f_bin``.
			device: Target passed to ``Tensor.to`` for the buffered frame.

		Returns:
			``0`` while warming up, ``1`` between outputs, otherwise a tensor
			of shape ``(interbal, frame_len, ch - 1, f_bin)`` (also stored in
			``self.output``); see the class docstring for the frame layout.

		Raises:
			ValueError: If ``x`` does not have the expected shape. Without this
				check a mis-shaped frame could be silently reinterpreted by
				``view`` whenever the total element counts happen to match.
		"""
		if x.ndim != 2 or x.shape[0] != self.ch or x.shape[1] < self.f_bin:
			raise ValueError(
				f"expected a frame of shape ({self.ch}, >={self.f_bin}), "
				f"got {tuple(x.shape)}"
			)
		x = x[:, :self.f_bin]
		# Reference-subtract row 0 from the remaining rows; row 0 itself is dropped.
		diff = x[1:] - x[0]
		self.stack.append(torch.from_numpy(diff).to(device).unsqueeze(0))
		self.count += 1

		if self.count < self.stack_len:
			return 0  # warm-up: buffer not yet filled for the first time
		if len(self.stack) != self.stack_len:
			return 1  # waiting for `interbal` new frames since the last output

		window = torch.cat(self.stack, dim=0)  # (stack_len, ch - 1, f_bin)
		window = window.view(self.frame_len, self.interbal, self.ch - 1, self.f_bin)
		# After the permute, [j, i] holds buffered frame i * interbal + j.
		self.output = window.permute(1, 0, 2, 3)
		# Slide the window forward by `interbal` frames.
		self.stack = self.stack[self.interbal:]
		return self.output
