import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class VivitInputLayer(nn.Module):
    """Tubelet embedding for ViViT on multichannel spectrogram-like input.

    The input is cut along time into tubes of ``split_time`` consecutive
    frames. Each tube is split into (array x frequency) patches by two strided
    convolutions, one applied to ``sin(x)`` and one to ``cos(x)``; their outputs
    are summed. A learnable CLS token is prepended to every tube's patch
    sequence and a learnable positional embedding is added.
    """

    def __init__(
        self,
        split_time: int = 2,
        emb_dim: int = 250,
        num_array: int = 7,
        n_fft=512,
        num_patch_h: int = 1,
        num_patch_w: int = 128,
        time_len=26,
    ):
        """
        Args:
            split_time: number of time frames in one time tube; used as the
                input-channel count of the patch convolutions.
            emb_dim: length of each token embedding.
            num_array: size of the array-channel (A) axis.
            n_fft: FFT size; the frequency patch width is
                ``n_fft // 2 // num_patch_w``.
            num_patch_h: number of patches along the array axis.
            num_patch_w: number of patches along the frequency axis.
            time_len: stored as ``self.time_len``; not otherwise used here.
        """
        super().__init__()
        self.split_time = split_time
        self.emb_dim = emb_dim
        self.num_array = num_array
        self.n_fft = n_fft
        self.num_patch_h = num_patch_h
        self.num_patch_w = num_patch_w
        # Patch extent along each axis. Integer floor division throughout; this
        # equals the former int(n_fft / 2 // num_patch_w) for positive inputs.
        self.patch_size_h = int(self.num_array // self.num_patch_h)
        self.patch_size_w = int(self.n_fft // 2 // self.num_patch_w)
        self.time_len = time_len

        patch = (self.patch_size_h, self.patch_size_w)
        # Two parallel non-overlapping patch projections: one sees sin(x), the
        # other cos(x). Creation order is kept so seeded init is unchanged.
        self.sinConv = nn.Conv2d(
            in_channels=self.split_time, out_channels=self.emb_dim,
            kernel_size=patch, stride=patch,
        )
        self.cosConv = nn.Conv2d(
            in_channels=self.split_time, out_channels=self.emb_dim,
            kernel_size=patch, stride=patch,
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))
        # One slot per patch plus one for the CLS token.
        self.pos_emb = nn.Parameter(
            torch.randn(1, self.num_patch_h * self.num_patch_w + 1, emb_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch into one token sequence per time tube.

        Args:
            x: tensor of shape (B, T, A, F): batch, time frames, array
                channels, frequency bins. T must be divisible by
                ``split_time``. A and F must make the convolutions produce
                exactly ``num_patch_h * num_patch_w`` patches, otherwise the
                positional-embedding addition fails.

        Returns:
            Tensor of shape
            (B * T // split_time, num_patch_h * num_patch_w + 1, emb_dim),
            with the CLS token at position 0 of every sequence.
        """
        batch, frames, arrays, freqs = x.shape
        n_tubes = frames // self.split_time
        # (B, T1*T2, A, F) -> (B*T1, T2, A, F) with T2 = split_time; identical to
        # einops 'B (T1 T2) A F -> (B T1) T2 A F'. Raises if T % split_time != 0.
        tubes = x.reshape(batch * n_tubes, self.split_time, arrays, freqs)

        # Sin/cos encoding before the patch projection. The original code labels
        # this branch "von Mises", which suggests x holds phase angles -- TODO confirm.
        z = self.sinConv(torch.sin(tubes)) + self.cosConv(torch.cos(tubes))

        # (B*T1, D, Ph, Pw) -> (B*T1, Ph*Pw, D)
        z = z.flatten(2).transpose(1, 2)

        cls = self.cls_token.expand(z.size(0), -1, -1)
        z = torch.cat([cls, z], dim=1)
        return z + self.pos_emb


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Every token attends to every token (no masking). Equation numbers in the
    comments are carried over from the original implementation's comments.
    """

    def __init__(self, emb_dim: int = 250, head: int = 5, dropout: float = 0.2):
        """
        Args:
            emb_dim: length of the token vectors (input and output).
            head: number of attention heads; must evenly divide ``emb_dim``.
            dropout: dropout rate applied to the attention weights and after
                the output projection.

        Raises:
            ValueError: if ``head`` is not positive or does not divide
                ``emb_dim``. Without this check, construction would succeed
                and ``forward`` would fail later inside ``view``.
        """
        super().__init__()
        if head <= 0 or emb_dim % head != 0:
            raise ValueError(
                f"emb_dim ({emb_dim}) must be divisible by a positive head count ({head})"
            )
        self.head = head
        self.emb_dim = emb_dim
        self.head_dim = emb_dim // head
        # sqrt(D_h): scale dividing q @ k^T.
        self.sqrt_dh = self.head_dim ** 0.5

        # Linear maps from the input to q, k, v [Eq. (6)].
        self.w_q = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_k = nn.Linear(emb_dim, emb_dim, bias=False)
        self.w_v = nn.Linear(emb_dim, emb_dim, bias=False)

        self.attn_drop = nn.Dropout(dropout)

        # Output projection [Eq. (10)], followed by dropout (not in the equation).
        self.w_o = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Dropout(dropout),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: input tokens of shape (B, N, D) with D == emb_dim.

        Returns:
            Tensor of shape (B, N, D) [Eq. (10)].
        """
        batch_size, num_patch, _ = z.size()

        # Project [Eq. (6)] and split into heads [Eq. (10)]:
        # (B, N, D) -> (B, N, h, D//h) -> (B, h, N, D//h)
        q = self.w_q(z).view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)
        k = self.w_k(z).view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)
        v = self.w_v(z).view(batch_size, num_patch, self.head, self.head_dim).transpose(1, 2)

        # Scaled dot products [Eq. (7)]:
        # (B, h, N, D//h) @ (B, h, D//h, N) -> (B, h, N, N)
        dots = (q @ k.transpose(2, 3)) / self.sqrt_dh
        # Softmax over the key axis, so each query's weights sum to 1.
        attn = self.attn_drop(F.softmax(dots, dim=-1))

        # Weighted sum of values [Eq. (8)]: (B, h, N, N) @ (B, h, N, D//h)
        # -> (B, h, N, D//h) -> (B, N, h, D//h) -> (B, N, D)
        out = (attn @ v).transpose(1, 2).reshape(batch_size, num_patch, self.emb_dim)

        # Output projection [Eq. (10)]
        return self.w_o(out)

class VivitEncoderBlock(nn.Module):
    """Pre-LayerNorm Transformer encoder block.

    Computes ``y = z + MHSA(LN1(z))`` then ``y + MLP(LN2(y))``.
    """

    def __init__(self, emb_dim: int = 250, head: int = 5, hidden_dim: int = 250, dropout: float = 0.5):
        """
        Args:
            emb_dim: token vector length.
            head: number of attention heads.
            hidden_dim: width of the MLP hidden layer. Note that the default
                (250) equals the default emb_dim; it is not 4 * emb_dim.
            dropout: dropout rate passed to the attention sub-layer. The MLP
                itself applies no dropout.
        """
        super().__init__()
        # Sub-layer creation order fixes parameter-initialisation order, and
        # the attribute names/indices are the state_dict keys
        # (ln1, msa, ln2, mlp.0, mlp.2), so both are kept stable.
        self.ln1 = nn.LayerNorm(emb_dim)
        self.msa = MultiHeadSelfAttention(emb_dim=emb_dim, head=head, dropout=dropout)
        self.ln2 = nn.LayerNorm(emb_dim)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: tokens of shape (B, N, D).

        Returns:
            Tensor of shape (B, N, D).
        """
        # Attention sub-layer with residual connection [Eq. (12)].
        attended = z + self.msa(self.ln1(z))
        # Feed-forward sub-layer with residual connection [Eq. (13)].
        return attended + self.mlp(self.ln2(attended))

class Vivit_temp_InputLayer(nn.Module):
    """Turn spatial-encoder output into per-clip sequences for a temporal encoder.

    Keeps only token 0 (the CLS token) of every tube, regroups tubes by clip,
    and prepends a learnable temporal CLS token.
    """

    def __init__(self, split_time: int = 2, emb_dim: int = 250, time_len=26):
        """
        Args:
            split_time: frames per tube.
            emb_dim: token vector length.
            time_len: frames per clip; ``time_len // split_time`` is the
                number of tubes per clip.
        """
        super().__init__()
        self.split_time = split_time
        self.cls_token_tempo = nn.Parameter(torch.randn(1, 1, emb_dim))
        self.time_len = time_len

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Args:
            z: tensor of shape (B * T', N, D), T' = time_len // split_time.

        Returns:
            Tensor of shape (B, T' + 1, D), temporal CLS token first.
        """
        tubes_per_clip = self.time_len // self.split_time
        first_tokens = z[:, 0]
        _, dim = first_tokens.shape
        # (B*T', D) -> (B, T', D); same grouping as einops '(B T) D -> B T D'.
        clips = first_tokens.reshape(-1, tubes_per_clip, dim)
        cls = self.cls_token_tempo.expand(clips.size(0), -1, -1)
        return torch.cat((cls, clips), dim=1)

def temp_input_ave_pooling(z, split_time: int = 3, emb_dim: int = 250, time_len=81):
    """Regroup the per-tube CLS tokens of a spatial-encoder output by clip.

    Despite the name, no averaging happens here. The function selects token 0
    of every tube and reshapes (B * T', N, D) -> (B, T', D), where
    T' = time_len // split_time. ``emb_dim`` is accepted but unused.
    """
    tubes_per_clip = time_len // split_time
    first_tokens = z[:, 0]
    _, dim = first_tokens.shape
    # Same grouping as einops '(B T) D -> B T D' with T = tubes_per_clip.
    return first_tokens.reshape(-1, tubes_per_clip, dim)

class Vit(nn.Module):
    """Factorised (spatial-then-temporal) ViViT classifier.

    Pipeline: tubelet embedding -> spatial encoder over each tube's patches ->
    per-tube CLS tokens regrouped by clip -> temporal positional embedding ->
    temporal encoder -> mean over time -> LayerNorm + Linear head.
    """

    def __init__(
        self,
        num_classes: int = 72,
        split_time: int = 2,
        emb_dim: int = 250,
        num_array: int = 7,
        n_fft=512,
        num_patch_h: int = 1,
        num_patch_w: int = 128,
        num_blocks: int = 8,
        head: int = 5,
        hidden_dim: int = 250,
        dropout: float = 0.5,
        time_len=26,
    ):
        """
        Args:
            num_classes: number of output classes.
            split_time: frames per time tube.
            emb_dim: token vector length.
            num_array: size of the array-channel (A) axis of the input.
            n_fft: FFT size; the input layer derives its frequency patch
                width from it (n_fft // 2 // num_patch_w).
            num_patch_h: patches along the array axis.
            num_patch_w: patches along the frequency axis.
            num_blocks: encoder blocks in EACH of the spatial and temporal
                encoders.
            head: attention heads per block.
            hidden_dim: MLP hidden width inside each encoder block.
            dropout: dropout rate passed to every encoder block.
            time_len: frames per input clip; the temporal positional
                embedding has ``time_len // split_time`` slots.
        """
        super().__init__()

        def make_encoder():
            # A stack of num_blocks identically-configured encoder blocks.
            return nn.Sequential(*[
                VivitEncoderBlock(
                    emb_dim=emb_dim,
                    head=head,
                    hidden_dim=hidden_dim,
                    dropout=dropout,
                )
                for _ in range(num_blocks)
            ])

        # Construction order is kept so seeded initialisation is unchanged;
        # attribute names are state_dict keys and must not be renamed.
        self.input_layer = VivitInputLayer(
            split_time=split_time,
            emb_dim=emb_dim,
            num_array=num_array,
            n_fft=n_fft,
            num_patch_h=num_patch_h,
            num_patch_w=num_patch_w,
            time_len=time_len,
        )
        # Spatial encoder: attention among the patches of one tube.
        self.encoder_pos = make_encoder()
        # Temporal positional embedding, one slot per tube in a clip.
        self.pos_emb = nn.Parameter(torch.randn(1, time_len // split_time, emb_dim))
        # Temporal encoder: attention among the tubes of one clip.
        # "encorder" (sic) is kept so existing checkpoints still load.
        self.encorder_temp = make_encoder()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input of shape (B, T, A, F) (batch, time frames, array
                channels, frequency bins). T must be divisible by split_time,
                and T // split_time must equal time_len // split_time so the
                temporal positional embedding broadcasts.

        Returns:
            Tensor of shape (B, num_classes): raw class scores (the head ends
            in a Linear layer; no softmax is applied).
        """
        batch = x.size(0)
        # (B, T, A, F) -> (B*T', N+1, D), T' = T // split_time
        tokens = self.input_layer(x)
        # Spatial attention within each tube; shape unchanged.
        tokens = self.encoder_pos(tokens)
        # Keep each tube's CLS token and regroup by clip: (B*T', D) -> (B, T', D);
        # same as einops '(B T) D -> B T D' with B given.
        tube_cls = tokens[:, 0]
        seq = tube_cls.reshape(batch, -1, tube_cls.size(-1))
        seq = self.encorder_temp(seq + self.pos_emb)
        # There is no temporal CLS token: average over the temporal tokens.
        pooled = seq.mean(dim=1)
        return self.mlp_head(pooled)

