import itertools
import numpy as np
import os
import tqdm
from typing import Optional


def get_embedding(
    token_embeddings: np.ndarray,
    token_ids: np.ndarray,
) -> np.ndarray:
    """
    Look up the embedding vector of every token in a sequence.

    Parameters
    ----------
    - `token_embeddings`: A 2D numpy array of shape `(vocab_size, embedding_dim)`
    - `token_ids`: A 1D integer numpy array of shape `(sequence_length,)`.
        Ids are expected in `[0, vocab_size)`. Negative ids follow numpy
        indexing rules and count from the end of the vocabulary.

    Returns
    -------
    - A 2D numpy array `features` of shape `(sequence_length, embedding_dim)`
    such that its `i`-th row is the embedding of the token with id `token_ids[i]`.
    The result is a new array, so modifying it does not modify `token_embeddings`.

    Raises
    ------
    - `IndexError`: if a token id is outside `[-vocab_size, vocab_size)`.
    """
    # Integer-array ("fancy") indexing on the first axis picks one row per id
    # and always returns a copy of shape `token_ids.shape + (embedding_dim,)`.
    return np.asarray(token_embeddings)[np.asarray(token_ids)]



def get_kqv(
    key_parameters: np.ndarray,
    query_parameters: np.ndarray,
    residual_input: np.ndarray,
    value_parameters: np.ndarray,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Project the residual stream into keys, queries, and values.

    Parameters
    ----------
    - `key_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
    - `query_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
    - `residual_input`: A 2D numpy array of shape `(sequence_length, embedding_dim)`
    - `value_parameters`: A 2D numpy array of shape `(embedding_dim, value_dim)`

    Returns
    -------
    A tuple of three 2D numpy arrays:
    - `keys`: The key matrix of shape `(sequence_length, key_dim)`.
        It is the matrix product of `residual_input` and `key_parameters`.
    - `queries`: The query matrix of shape `(sequence_length, key_dim)`.
        It is the matrix product of `residual_input` and `query_parameters`.
    - `values`: The value matrix of shape `(sequence_length, value_dim)`.
        It is the matrix product of `residual_input` and `value_parameters`.

    Raises
    ------
    - `ValueError`: if the inner dimensions of a product do not match.
    """
    keys = residual_input @ key_parameters
    queries = residual_input @ query_parameters
    values = residual_input @ value_parameters
    return keys, queries, values


def get_attention_logits(
    keys: np.ndarray,
    queries: np.ndarray
) -> np.ndarray:
    """
    Compute scaled dot-product attention logits.

    Parameters
    ----------
    - `keys`: A 2D numpy array of shape `(sequence_length, key_dim)`
    - `queries`: A 2D numpy array of shape `(sequence_length, key_dim)`

    Returns
    -------
    - A 2D numpy array `logits` of shape `(sequence_length, sequence_length)`
        such that `logits[i, j]` is the dot product of the `i`-th row of `queries`
        and the `j`-th row of `keys`, divided by the square root of `key_dim`.
        Leading batch axes, if present, are carried through unchanged.
    """
    key_dim = keys.shape[-1]
    # Swap only the last two axes (rather than using `.T`, which reverses all
    # of them) so that stacked/batched inputs are handled correctly as well.
    return (queries @ np.swapaxes(keys, -1, -2)) / np.sqrt(key_dim)


def get_attention_weights(attention_logits: np.ndarray) -> np.ndarray:
    """
    Apply a numerically stable row-wise softmax to attention logits.

    Parameters
    ----------
    - `attention_logits`: A 2D numpy array of shape `(sequence_length, sequence_length)`

    Returns
    -------
    - A 2D numpy array `attention_weights` of shape `(sequence_length, sequence_length)`
        such that `attention_weights[i]` is the softmax of the `i`-th row of `attention_logits`. That is:
        1. we subtract the maximum value of the `i`-th row of `attention_logits` from each element of that row,
        to avoid overflow in the next step;
        2. we exponentiate each element of the shifted row, and
        3. we divide each exponentiated element by the sum of the exponentiated elements of that row.
        The softmax is taken over the last axis, so leading batch axes are also supported.
    """
    logits = np.asarray(attention_logits)
    # Softmax is invariant to adding a constant to a row. Shifting each row by
    # its maximum keeps every exponent <= 0, so `np.exp` cannot overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exponentials = np.exp(shifted)
    return exponentials / exponentials.sum(axis=-1, keepdims=True)


def get_same_token_attention_kq_parameters(
    apply_softmax: bool,
    vocab_size: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Outputs the key and query parameters that make every position attend only
    to positions holding the same token.

    The attention matrix A is indexed by sequence positions, and "i == j" below
    means the tokens at positions i and j have the same id. Up to the numerical
    precision given by `np.isclose`:
    ```
    A[i, j] = 1 / (number of tokens of type i), if i == j and apply_softmax,
    A[i, j] = 1, if i == j and not apply_softmax,
    A[i, j] = 0, otherwise.
    ```
    When `apply_softmax` is False, A is the matrix of attention logits. When it
    is True, A is the softmax of those logits.

    We expect the token embedding matrix to be the identity matrix, so the
    features are one-hot encodings of the token ids.

    Parameters
    ----------
    - `apply_softmax`: A boolean indicating whether to apply softmax to the attention logits.
    - `vocab_size`: An integer representing the size of the vocabulary.

    Returns
    -------
    - `key_parameters`: A 2D numpy array of shape `(vocab_size, key_dim)`.
    - `query_parameters`: A 2D numpy array of shape `(vocab_size, key_dim)`.

    Here `key_dim == vocab_size`.
    """
    # Without softmax the logits must be exactly 1 for matching tokens. With
    # softmax, a logit gap of 50 gives a non-matching token exp(-50) ~ 2e-22 of
    # a matching token's weight, far below `np.isclose` tolerances.
    logit_gap = 50.0 if apply_softmax else 1.0
    key_dim = vocab_size

    # For one-hot features, keys are the one-hot vectors themselves, so
    # queries @ keys.T is non-zero only where the two tokens share a type.
    key_parameters = np.eye(vocab_size, key_dim)
    # Multiplying by sqrt(key_dim) cancels the 1 / sqrt(key_dim) scaling of the
    # attention logits, so a matching pair gets a logit of exactly `logit_gap`.
    query_parameters = logit_gap * np.sqrt(key_dim) * np.eye(vocab_size, key_dim)
    return key_parameters, query_parameters


def get_different_token_attention_kq_parameters(
    vocab_size: int
) -> tuple[np.ndarray, np.ndarray]:
    """
    Outputs the key and query parameters that make every position attend
    uniformly to the positions holding a different token.

    The attention matrix A is indexed by sequence positions, and "i == j" below
    means the tokens at positions i and j have the same id. Up to the numerical
    precision given by `np.isclose`, the attention weight matrix A (that is,
    the softmax of the attention logits) has:
    ```
    A[i, j] = 0 if i == j, and
    A[i, j] = 1 / (number of tokens of type other than i) otherwise.
    ```
    If a row has no token of another type, every logit in it is equal, so the
    softmax spreads the weight uniformly over that row.

    We expect the token embedding matrix to be the identity matrix, so the
    features are one-hot encodings of the token ids.

    Parameters
    ----------
    - `vocab_size`: An integer representing the size of the vocabulary.

    Returns
    -------
    - `key_parameters`: A 2D numpy array of shape `(vocab_size, key_dim)`.
    - `query_parameters`: A 2D numpy array of shape `(vocab_size, key_dim)`.

    Here `key_dim == vocab_size`.
    """
    # Matching tokens get a logit of -50 and all others 0. After softmax each
    # matching token keeps only ~exp(-50) ~ 2e-22 of a non-matching one's weight.
    logit_gap = 50.0
    key_dim = vocab_size

    key_parameters = np.eye(vocab_size, key_dim)
    # sqrt(key_dim) cancels the 1 / sqrt(key_dim) scaling of the attention
    # logits. The negative sign penalizes, rather than rewards, matching tokens.
    query_parameters = -logit_gap * np.sqrt(key_dim) * np.eye(vocab_size, key_dim)
    return key_parameters, query_parameters


def test_get_embedding(
    test_cases_path="get_embedding_test_cases.npz"
):
    """
    Check `get_embedding` against the reference outputs stored in an `.npz` archive.

    For every archive entry named `{d}_{l}_{v}_features`, the sibling entries
    `{d}_{l}_{v}_embedding` and `{d}_{l}_{v}_token_ids` are loaded. The three
    arrays are zipped along their first axis, so each leading-axis slice forms
    one test case.

    Parameters
    ----------
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed feature matrix differs from the expected one.
    """
    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # An .npz archive keeps its file handle open while it is in use. The `with`
    # block closes it once every needed array has been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_features"):
                d, l, v = map(int, key.split("_")[:3])
                token_embeddings = loaded[f"{d}_{l}_{v}_embedding"]
                token_ids = loaded[f"{d}_{l}_{v}_token_ids"]
                features = loaded[key]
                test_cases.extend(zip(token_embeddings, token_ids, features))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_embedding",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for token_embeddings, token_ids, features in test_cases:
            computed_features = get_embedding(token_embeddings, token_ids)

            # A lookup involves no arithmetic, so exact equality is required.
            assert np.array_equal(computed_features, features), f"Features mismatch for token embeddings shape {token_embeddings.shape} and token ids shape {token_ids.shape}, expected {features}, got {computed_features}"

            progress_bar.update(1)

    print("All tests passed for `get_embedding`.")



def test_get_kqv(test_cases_path="get_kqv_test_cases.npz"):
    """
    Check `get_kqv` against the reference outputs stored in an `.npz` archive.

    For every archive entry named `{d}_{l}_keys`, the sibling entries
    `{d}_{l}_key_parameters`, `_query_parameters`, `_queries`, `_residual_input`,
    `_value_parameters` and `_values` are loaded. They are zipped along their
    first axis into individual test cases.

    Parameters
    ----------
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed keys, queries or values are not close
        (per `np.allclose`) to the expected ones.
    """
    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # Close the .npz file handle once all arrays have been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_keys"):
                d, l = map(int, key.split("_")[:2])
                key_parameters = loaded[f"{d}_{l}_key_parameters"]
                keys = loaded[f"{d}_{l}_keys"]
                query_parameters = loaded[f"{d}_{l}_query_parameters"]
                queries = loaded[f"{d}_{l}_queries"]
                residual_input = loaded[f"{d}_{l}_residual_input"]
                value_parameters = loaded[f"{d}_{l}_value_parameters"]
                values = loaded[f"{d}_{l}_values"]
                test_cases.extend(zip(key_parameters, keys, query_parameters, queries, residual_input, value_parameters, values))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_kqv",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for key_parameters, keys, query_parameters, queries, residual_input, value_parameters, values in test_cases:
            computed_keys, computed_queries, computed_values = get_kqv(
                key_parameters=key_parameters,
                query_parameters=query_parameters,
                residual_input=residual_input,
                value_parameters=value_parameters
            )

            # Matrix products accumulate rounding error, so compare with a tolerance.
            assert np.allclose(computed_keys, keys), f"Keys mismatch for key parameters shape {key_parameters.shape} and residual input shape {residual_input.shape}, expected {keys}, got {computed_keys}"
            assert np.allclose(computed_queries, queries), f"Queries mismatch for query parameters shape {query_parameters.shape} and residual input shape {residual_input.shape}, expected {queries}, got {computed_queries}"
            assert np.allclose(computed_values, values), f"Values mismatch for value parameters shape {value_parameters.shape} and residual input shape {residual_input.shape}, expected {values}, got {computed_values}"

            progress_bar.update(1)

    print("All tests passed for `get_kqv`.")


def test_get_attention_logits(
    test_cases_path="get_attention_logits_test_cases.npz"
):
    """
    Check `get_attention_logits` against the reference outputs stored in an `.npz` archive.

    For every archive entry named `{d}_{l}_attention_logits`, the sibling
    entries `{d}_{l}_keys` and `{d}_{l}_queries` are loaded and zipped with it
    along the first axis into individual test cases.

    Parameters
    ----------
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed logits are not close to the expected ones.
    """
    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # Close the .npz file handle once all arrays have been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_attention_logits"):
                d, l = map(int, key.split("_")[:2])
                keys = loaded[f"{d}_{l}_keys"]
                queries = loaded[f"{d}_{l}_queries"]
                logits = loaded[key]
                test_cases.extend(zip(keys, queries, logits))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_attention_logits",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for keys, queries, logits in test_cases:
            computed_logits = get_attention_logits(keys, queries)

            assert np.allclose(computed_logits, logits), f"Logits mismatch for keys shape {keys.shape} and queries shape {queries.shape}"

            progress_bar.update(1)

    print("All tests passed for `get_attention_logits`.")


def test_get_attention_weights(
    test_cases_path="get_attention_weights_test_cases.npz"
):
    """
    Check `get_attention_weights` against the reference outputs stored in an `.npz` archive.

    For every archive entry named `{l}_attention_weights`, the sibling entry
    `{l}_attention_logits` is loaded and zipped with it along the first axis
    into individual test cases.

    Parameters
    ----------
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed weights are not close to the expected ones.
    """
    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # Close the .npz file handle once all arrays have been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_attention_weights"):
                l = int(key.split("_")[0])
                logits = loaded[f"{l}_attention_logits"]
                weights = loaded[key]
                test_cases.extend(zip(logits, weights))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_attention_weights",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for logits, weights in test_cases:
            computed_weights = get_attention_weights(logits)

            assert np.allclose(computed_weights, weights), f"Weights mismatch for logits shape {logits.shape}"

            progress_bar.update(1)

    print("All tests passed for `get_attention_weights`.")


def test_same_token_attention(
    apply_softmax: bool,
    test_cases_path="same_token_attention_test_cases.npz",
    vocab_size: Optional[int] = None
):
    """
    Check `get_same_token_attention_kq_parameters` end-to-end.

    Each test case's token ids are embedded with an identity embedding matrix
    and projected with the returned key/query parameters. The attention logits
    (softmaxed when `apply_softmax` is True) are compared against the reference
    attention matrix.

    For every archive entry named `{l}_{v}_token_ids`, the siblings
    `{l}_{v}_apply_softmax` and `{l}_{v}_attention_matrices` are loaded. Only
    the cases whose stored flag equals `apply_softmax` are kept.

    Parameters
    ----------
    - `apply_softmax`: Which family of test cases to run, and whether to apply softmax.
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.
    - `vocab_size`: If given, only cases with this vocabulary size are run.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed attention matrix is not close to the expected one.
    """
    test_name = f"`get_same_token_attention_kq_parameters` with apply_softmax={apply_softmax}"
    if vocab_size is not None:
        test_name += f" and vocab_size={vocab_size}"

    print(f"Testing {test_name}")

    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # Close the .npz file handle once all arrays have been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_token_ids"):
                l, v = map(int, key.split("_")[:2])
                if vocab_size is not None and vocab_size != v:
                    continue

                # Coerce to bool so the flags always act as a boolean mask. If they
                # were stored as integers, `~` would be a bitwise NOT (0 -> -1,
                # 1 -> -2), and indexing would silently select rows by position.
                softmax_flags = np.asarray(loaded[f"{l}_{v}_apply_softmax"], dtype=bool)
                attention_matrices = loaded[f"{l}_{v}_attention_matrices"]
                token_ids = loaded[key]

                mask = softmax_flags if apply_softmax else ~softmax_flags
                test_cases.extend(zip(attention_matrices[mask], token_ids[mask], itertools.repeat(v)))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_same_token_attention_kq_parameters",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for attention_matrix, token_ids, v in test_cases:
            (
                key_parameters,
                query_parameters,
            ) = get_same_token_attention_kq_parameters(
                apply_softmax=apply_softmax,
                vocab_size=v
            )
            # The parameters are designed for one-hot features, so embed with the identity.
            token_embeddings = np.eye(v)
            features = get_embedding(
                token_embeddings=token_embeddings,
                token_ids=token_ids
            )

            # Values are not checked here; the key parameters are reused only to
            # satisfy `get_kqv`'s signature.
            keys, queries, _ = get_kqv(
                key_parameters=key_parameters,
                query_parameters=query_parameters,
                residual_input=features,
                value_parameters=key_parameters
            )
            attention_matrix_computed = get_attention_logits(
                keys=keys,
                queries=queries
            )
            if apply_softmax:
                attention_matrix_computed = get_attention_weights(attention_matrix_computed)

            assert np.allclose(attention_matrix, attention_matrix_computed), f"Attention matrix mismatch for token ids {token_ids}, expected {attention_matrix}, got {attention_matrix_computed}"

            progress_bar.update(1)

    print(f"All tests passed for {test_name}.")


def test_different_token_attention(
    test_cases_path="different_token_attention_test_cases.npz",
    vocab_size: Optional[int] = None
):
    """
    Check `get_different_token_attention_kq_parameters` end-to-end.

    Each test case's token ids are embedded with an identity embedding matrix
    and projected with the returned key/query parameters. The softmax of the
    resulting attention logits is compared against the reference attention
    matrix.

    For every archive entry named `{l}_{v}_token_ids`, the sibling entry
    `{l}_{v}_attention_matrices` is loaded and zipped with it along the first
    axis into individual test cases.

    Parameters
    ----------
    - `test_cases_path`: Path to the `.npz` archive holding the test cases.
    - `vocab_size`: If given, only cases with this vocabulary size are run.

    Raises
    ------
    - `FileNotFoundError`: if `test_cases_path` does not exist.
    - `AssertionError`: if any computed attention matrix is not close to the expected one.
    """
    test_name = "`get_different_token_attention_kq_parameters`"
    if vocab_size is not None:
        test_name += f" with vocab_size={vocab_size}"

    print(f"Testing {test_name}")

    if not os.path.exists(test_cases_path):
        raise FileNotFoundError(f"Test cases file not found: {test_cases_path}")

    print("Loading test cases.")
    test_cases = []
    # Close the .npz file handle once all arrays have been read into memory.
    with np.load(test_cases_path) as loaded:
        for key in loaded.keys():
            if key.endswith("_token_ids"):
                l, v = map(int, key.split("_")[:2])
                if vocab_size is not None and vocab_size != v:
                    continue

                attention_matrices = loaded[f"{l}_{v}_attention_matrices"]
                token_ids = loaded[key]
                test_cases.extend(zip(attention_matrices, token_ids, itertools.repeat(v)))

    print(f"Loaded {len(test_cases)} test cases.")

    # Using the progress bar as a context manager closes it even if an assertion fails.
    with tqdm.tqdm(
        desc="Testing get_different_token_attention_kq_parameters",
        total=len(test_cases),
        unit="test_case"
    ) as progress_bar:
        for attention_matrix, token_ids, v in test_cases:
            (
                key_parameters,
                query_parameters,
            ) = get_different_token_attention_kq_parameters(
                vocab_size=v
            )
            # The parameters are designed for one-hot features, so embed with the identity.
            token_embeddings = np.eye(v)
            features = get_embedding(
                token_embeddings=token_embeddings,
                token_ids=token_ids
            )

            # Values are not checked here; the key parameters are reused only to
            # satisfy `get_kqv`'s signature.
            keys, queries, _ = get_kqv(
                key_parameters=key_parameters,
                query_parameters=query_parameters,
                residual_input=features,
                value_parameters=key_parameters
            )
            attention_matrix_computed = get_attention_logits(
                keys=keys,
                queries=queries
            )
            attention_matrix_computed = get_attention_weights(attention_matrix_computed)

            assert np.allclose(attention_matrix, attention_matrix_computed), f"Attention matrix mismatch for token ids {token_ids}, expected {attention_matrix}, got {attention_matrix_computed}"

            progress_bar.update(1)

    # Report which variant passed, matching `test_same_token_attention`. With no
    # vocab_size filter, this prints the same message as before.
    print(f"All tests passed for {test_name}.")


if __name__ == "__main__":
    # Building blocks first: each of these depends only on its own archive.
    for run_unit_test in (
        test_get_embedding,
        test_get_kqv,
        test_get_attention_logits,
        test_get_attention_weights,
    ):
        run_unit_test()

    # End-to-end attention puzzles: the small vocabulary first (easier to
    # debug), then every vocabulary size (`None` disables the filter).
    for restricted_vocab in (2, None):
        for softmax_flag in (False, True):
            test_same_token_attention(apply_softmax=softmax_flag, vocab_size=restricted_vocab)

    for restricted_vocab in (2, None):
        test_different_token_attention(vocab_size=restricted_vocab)

    print("All tests passed!")