import itertools
import math

import numpy as np

from model import apply_decoder

# Number of permutations passed to `apply_decoder` per call in
# `test_get_set_complement_parameters`.
# Lower this if you run into memory issues.
BATCH_SIZE = 1024


def get_permutations(size: int) -> np.ndarray:
    """
    Generate all permutations of integers from 0 to size - 1.

    Rows are emitted in the order produced by `itertools.permutations(range(size))`,
    i.e. lexicographic order.

    Parameters:
    size (int): The size of the permutation. Must be non-negative.

    Returns:
    np.ndarray: An integer array of shape `(math.factorial(size), size)` containing all permutations, one per row.

    Raises:
    ValueError: If `size` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")

    count = math.factorial(size)
    # Stream the flattened permutations straight into a preallocated array
    # instead of first materializing a list of `count` Python tuples.
    flat = np.fromiter(
        itertools.chain.from_iterable(itertools.permutations(range(size))),
        dtype=np.int64,
        count=count * size,
    )
    # Explicit shape (rather than -1) keeps the size == 0 case well-defined: (1, 0).
    return flat.reshape(count, size)


def get_valid_next_token_mask(permutations: np.ndarray) -> np.ndarray:
    """
    Compute, for every prefix of every permutation, which token IDs may still come next.

    Given an integer array `permutations` of shape `(batch_size, sequence_length)` whose rows are permutations of `range(sequence_length)`, return a boolean array `valid_next_token_mask` of shape `(batch_size, sequence_length, sequence_length)`, where for each `i in range(sequence_length)`: `valid_next_token_mask[:, i, :]` is True at positions corresponding to token IDs that have not yet appeared in `permutations[:, :i+1]`, and False otherwise.

    Raises:
    ValueError: If `permutations` is not two-dimensional.
    """
    permutations = np.asarray(permutations)
    if permutations.ndim != 2:
        raise ValueError(
            f"permutations must be 2-dimensional, got {permutations.ndim} dimension(s)"
        )

    sequence_length = permutations.shape[1]

    # The argsort of a permutation is its inverse, so
    # position_of_token[b, t] is the index at which token t occurs in row b.
    position_of_token = np.argsort(permutations, axis=1)

    # Token t is still unused after consuming positions 0..i exactly when it
    # occurs strictly later than position i.
    steps = np.arange(sequence_length)
    return position_of_token[:, None, :] > steps[None, :, None]


def get_set_complement_parameters(size: int, weight_sharing: bool = False) -> dict:
    """
    Get parameters for a single-layer, attention-only, single-head, decoder-only transformer that solves the following task: given a repetition-free sequence of integers from 0 to size - 1 of length at most size - 1, output a uniform distribution over the integers that have not yet appeared in the sequence.

    Because next-token probabilities are obtained with a softmax, they can never be exactly zero. "Solving the task" therefore means that, for all inputs, the total probability mass assigned to invalid next tokens is less than 1%.

    NOTE: this function is currently a stub and always raises `NotImplementedError`.

    Parameters
        size (int): The size of the integer set.
        weight_sharing (bool): Whether to use weight sharing between token embeddings and unembedding matrix. Default is False.

    Returns
        dict: A dictionary containing the model parameters. The dictionary has the following keys:
            - `key_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
            - `output_parameters`: A 2D numpy array of shape `(value_dim, embedding_dim)`
            - `query_parameters`: A 2D numpy array of shape `(embedding_dim, key_dim)`
            - `token_embeddings`: A 2D numpy array of shape `(size, embedding_dim)`
            - `unembedding`: A 2D numpy array of shape `(embedding_dim, size)`. This key is only present if `weight_sharing` is False.
            - `value_parameters`: A 2D numpy array of shape `(embedding_dim, value_dim)`

    Raises
        NotImplementedError: Always, until an implementation is provided.
    """
    raise NotImplementedError




def test_get_permutations():
    """
    Check that `get_permutations(size)` returns every permutation of `range(size)` exactly once, for sizes 1 through 9.
    """
    for size in range(1, 10):
        all_rows = get_permutations(size)

        # There are exactly size! permutations.
        expected_count = math.factorial(size)
        assert len(all_rows) == expected_count

        # No permutation may appear twice.
        distinct_rows = np.unique(all_rows, axis=0)
        assert len(distinct_rows) == len(all_rows)

        # Every row, once sorted, must be exactly 0, 1, ..., size - 1.
        row_contents = np.sort(distinct_rows, axis=1)
        assert np.all(row_contents == np.arange(size))


def test_get_valid_next_token_mask():
    """
    Check structural properties of `get_valid_next_token_mask` on all permutations of sizes 1 through 9: output shape, number of valid tokens per step, consumed tokens being invalid, and validity never being regained.
    """
    for size in range(1, 10):
        permutations = get_permutations(size)
        mask = get_valid_next_token_mask(permutations)
        batch = len(permutations)

        assert mask.shape == (batch, size, size)

        # After consuming i + 1 tokens, size - (i + 1) tokens remain valid.
        remaining = np.arange(size - 1, -1, -1)
        assert np.all(mask.sum(axis=-1) == remaining)

        # The token consumed at step i must be marked invalid at step i.
        consumed = np.take_along_axis(mask, permutations[:, :, None], axis=2)
        assert np.all(~consumed)

        # Anything valid at a later step must also have been valid one step earlier.
        earlier = mask[:, :-1]
        later = mask[:, 1:]
        assert np.all((earlier & later) == later)


def test_get_set_complement_parameters(weight_sharing: bool = False):
    """
    Check that the parameters from `get_set_complement_parameters` put less than 1% of the next-token probability mass on already-seen tokens, for every permutation of sizes 2 through 9.

    Parameters
        weight_sharing (bool): Forwarded both to `get_set_complement_parameters` and to `apply_decoder`, so the parameters are generated for the same configuration they are evaluated under. Default is False.
    """
    for size in range(2, 10):
        # Generate parameters for the configuration under test; without
        # forwarding `weight_sharing` the shared-weights run would evaluate
        # parameters built for the non-shared model.
        parameters = get_set_complement_parameters(size, weight_sharing=weight_sharing)
        permutations = get_permutations(size)

        # Evaluate in chunks of BATCH_SIZE to bound memory use.
        for start in range(0, len(permutations), BATCH_SIZE):
            batch_permutations = permutations[start:start + BATCH_SIZE]

            # NOTE(review): assumes `apply_decoder` returns logits of shape
            # (batch, sequence_length, size) - confirm against `model`.
            logits = apply_decoder(parameters, batch_permutations, weight_sharing=weight_sharing)
            valid_next_token_mask = get_valid_next_token_mask(batch_permutations)

            # Numerically stable softmax over the last axis.
            next_token_probabilities = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
            next_token_probabilities /= np.sum(next_token_probabilities, axis=-1, keepdims=True)

            # The last position is skipped: the full permutation has been
            # consumed there, so no token is a valid continuation.
            invalid_token_probabilities = next_token_probabilities[:, :-1] * (~valid_next_token_mask[:, :-1])
            total_invalid_token_probability = np.sum(invalid_token_probabilities, axis=-1)

            assert np.max(total_invalid_token_probability) < 0.01


if __name__ == "__main__":
    # The helper tests run first because the parameter test relies on both helpers.
    test_get_permutations()
    test_get_valid_next_token_mask()

    # Exercise the decoder without and then with embedding/unembedding weight sharing.
    for share_weights in (False, True):
        test_get_set_complement_parameters(weight_sharing=share_weights)

    print("All tests passed. Congratulations!")