import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from math import e, pi, sin, cos, log
import sys
from lib6003.fft import fft, ifft
j = complex(0.0, 1.0)  # shorthand for the imaginary unit, equal to 1j

from lib6003.audio import *

def stft(x, window_size, step_size, sample_rate):
    """Return the Short-Time Fourier Transform of ``x``.

    Consecutive (rectangular, untapered) windows of ``window_size`` samples
    are taken every ``step_size`` samples, starting at sample 0. Only windows
    that fit entirely inside ``x`` are used; trailing samples that do not
    fill a whole window are ignored.

    The result is a list of lists: ``output[n][k]`` is the kth DFT
    coefficient of the nth window, as computed by ``lib6003.fft.fft``
    (whatever scaling convention that function uses is inherited here).

    ``sample_rate`` is accepted for interface consistency with the other
    helpers in this module. It does not affect the coefficients.

    Raises:
        ValueError: if ``window_size`` or ``step_size`` is not positive.
    """
    if window_size <= 0 or step_size <= 0:
        raise ValueError("window_size and step_size must be positive")

    # Last start index whose window still lies fully inside x. If x is
    # shorter than one window, the range is empty and the result is [].
    last_start = len(x) - window_size
    return [
        list(fft(x[start:start + window_size]))
        for start in range(0, last_start + 1, step_size)
    ]


def k_to_hz(k, window_size, step_size, sample_rate):
    """Return the frequency in Hz associated with DFT bin ``k``.

    For an N-point DFT of a signal sampled at ``sample_rate`` Hz, bin k sits
    at ``k * sample_rate / N``. ``k`` may be fractional, because matplotlib
    tick positions are floats. Bins above N/2 alias to negative frequencies
    but are mapped linearly here.

    ``step_size`` does not affect frequency resolution. It is accepted only
    for a uniform signature.
    """
    return k * sample_rate / window_size


def hz_to_k(freq, window_size, step_size, sample_rate):
    """Return the DFT bin index nearest to ``freq`` Hz.

    This is the inverse of ``k_to_hz``: ``freq * window_size / sample_rate``,
    rounded with Python's ``round``. Exact .5 ties go to the even integer.

    ``step_size`` is accepted only for a uniform signature.
    """
    return round(freq * window_size / sample_rate)


def timestep_to_seconds(i, window_size, step_size, sample_rate):
    """Return the time in seconds of the middle of analysis window ``i``.

    Window i starts at sample ``i * step_size`` and spans ``window_size``
    samples, so its midpoint is at sample ``i * step_size + window_size / 2``.
    Dividing by ``sample_rate`` converts samples to seconds. The result is
    rounded to the nearest 0.01 s. ``i`` may be fractional, because
    matplotlib tick positions are floats.
    """
    centre_sample = i * step_size + window_size / 2
    return round(centre_sample / sample_rate, 2)


def transpose(x):
    """Return the transpose of ``x``, a list of equal-length lists.

    ``output[c][r] == x[r][c]``. An empty input gives ``[]``. Note that
    ``zip`` stops at the shortest row, so ragged input is truncated.
    """
    return [list(column) for column in zip(*x)]


def spectrogram(X, window_size, step_size, sample_rate):
    """Return the spectrogram (squared magnitude) of an STFT.

    ``X`` is indexed ``X[i][k]`` (window i, bin k), as produced by ``stft``.
    The result is indexed the other way round: ``output[k][i]`` is
    ``|X[i][k]| ** 2``, the power in frequency bin k for analysis window i.
    An empty ``X`` yields ``[]``.

    ``window_size``, ``step_size`` and ``sample_rate`` are accepted only for
    a uniform signature.
    """
    # zip(*X) walks X column by column, i.e. one frequency bin at a time,
    # which performs the [i][k] -> [k][i] reindexing.
    return [[abs(coeff) ** 2 for coeff in bin_values] for bin_values in zip(*X)]


def plot_spectrogram(sgram, window_size, step_size, sample_rate):
    """Display a spectrogram with matplotlib.

    ``sgram`` is indexed ``sgram[k][i]`` (frequency bin, window), as returned
    by ``spectrogram``. Only bins ``k = 0 .. N/2`` are plotted. Axes are
    labelled through ``timestep_to_seconds`` (horizontal) and ``k_to_hz``
    (vertical). Values are shown on a log scale, since human perception of
    loudness is roughly logarithmic.
    """
    width = len(sgram[0])
    height = len(sgram) // 2 + 1  # number of bins in k = 0 .. N/2

    # Slice to exactly `height` rows (bins 0 .. N/2). Adding the smallest
    # positive float keeps log() defined for zero-power entries.
    log_power = [
        [log(power + sys.float_info.min) for power in row]
        for row in sgram[:height]
    ]
    plt.imshow(log_power, aspect=width / height)
    # Setting y limits from 0 upward puts low frequencies at the bottom.
    plt.axis([0, width - 1, 0, height - 1])

    ticks = ticker.FuncFormatter(lambda x, pos: '{0:.1f}'.format(timestep_to_seconds(x, window_size, step_size, sample_rate)))
    plt.gca().xaxis.set_major_formatter(ticks)
    ticks = ticker.FuncFormatter(lambda y, pos: '{0:.0f}'.format(k_to_hz(y, window_size, step_size, sample_rate)))
    plt.gca().yaxis.set_major_formatter(ticks)

    plt.xlabel('time [s]')
    plt.ylabel('frequency [Hz]')

    plt.colorbar()
    plt.show()
