import numpy
from matplotlib import pyplot
from scipy.stats import ttest_ind
from scipy.stats import ttest_rel
from pygazeanalyser.edfreader import read_edf

# Read the tracker data file into a list of per-trial records. The second
# argument is the start message and `stop` the end message passed to
# read_edf (presumably they delimit each trial -- see read_edf's docs).
data = read_edf(
    'ED_pupil.asc',
    'PUPIL_TRIALSTART',
    stop='pupdata_stop',
)

# Number of pupil samples kept after the monitor change.
TRACE_LENGTH = 2000

# create a new dict to contain traces, one list per trial type
traces = {'black': [], 'white': []}

# loop through all trials
for i, trial in enumerate(data):
    msgs = trial['events']['msg']
    # The loop below needs the trial-type message plus the baseline and
    # monitor-change messages; anything shorter cannot be analysed.
    if len(msgs) < 3:
        print('Skipping trial %d: expected at least 3 messages' % i)
        continue

    # check the trial type, which is named in the first message
    _, msg = msgs[0]
    if 'black' in msg:
        trialtype = 'black'
    elif 'white' in msg:
        trialtype = 'white'
    else:
        # Skip explicitly; falling through would file this trial under the
        # previous trial's type (or raise NameError on the first trial).
        print('Skipping trial %d: no trial type in message %r' % (i, msg))
        continue

    # get the timestamps of baseline and monitor change
    t1, _ = msgs[1]
    t2, _ = msgs[2]

    # Turn the timestamps into integer sample indices: the first sample at or
    # after each timestamp. On an exact match this is the matching sample, and
    # it still works when a message falls between two samples. Assumes
    # trackertime is sorted ascending -- TODO confirm for the EDF reader.
    times = trial['trackertime']
    t1i = int(numpy.searchsorted(times, t1))
    t2i = int(numpy.searchsorted(times, t2))

    size = trial['size']
    # get the baseline trace
    baseline = size[t1i:t2i]
    # get the pupil change trace (TRACE_LENGTH samples)
    trace = size[t2i:t2i + TRACE_LENGTH]

    # An empty baseline has no median, and a short trace would make the
    # per-condition arrays ragged, so such trials are excluded.
    if len(baseline) == 0 or len(trace) < TRACE_LENGTH:
        print('Skipping trial %d: baseline or pupil trace too short' % i)
        continue

    # divide the pupil trace by the baseline median and store it
    traces[trialtype].append(trace / numpy.median(baseline))

# convert lists to NumPy arrays (one row per trial)
traces['black'] = numpy.array(traces['black'])
traces['white'] = numpy.array(traces['white'])

# create an empty dict to contain mean and SEM per condition
avgs = {'black': {}, 'white': {}}
# loop through both conditions
for con in ['black', 'white']:
    # number of trials (rows) in this condition
    n_trials = len(traces[con])
    # average trace across trials (axis 0 runs over trials)
    avgs[con]['M'] = numpy.mean(traces[con], axis=0)
    # Sample standard deviation (ddof=1): the standard error of the mean is
    # defined with the sample SD, as in scipy.stats.sem. NumPy's default
    # ddof=0 divides by N instead of N-1 and yields a smaller SEM.
    sd = numpy.std(traces[con], axis=0, ddof=1)
    # standard error of the mean in this condition
    avgs[con]['SEM'] = sd / numpy.sqrt(n_trials)

# Do a t-test on every timepoint (axis 0 runs over trials). Black and white
# traces come from different trials, so they are independent samples.
# ttest_rel would pair black trial k with white trial k merely by order, and
# it raises ValueError when the conditions have unequal trial counts.
t, p = ttest_ind(traces['black'], traces['white'], axis=0)

# Bonferroni-corrected alpha: 0.05 divided by the number of timepoints tested
alpha = 0.05 / len(t)

# Line and shading colour for each condition.
cols = {'black': '#204a87', 'white': '#c4a000'}
# A single-axis figure of 19.2 x 10.8 inches at 100 dpi (1920 x 1080 px).
fig, ax = pyplot.subplots(figsize=(19.2, 10.8), dpi=100.0)

for condition in ['black', 'white']:
    mean_trace = avgs[condition]['M']
    sem_trace = avgs[condition]['SEM']
    colour = cols[condition]
    # one x-value per sample of the mean trace
    x = range(len(mean_trace))
    # the condition's mean trace...
    ax.plot(x, mean_trace, '-', color=colour, label=condition)
    # ...inside a translucent band of +/- one standard error
    upper_band = mean_trace + sem_trace
    lower_band = mean_trace - sem_trace
    ax.fill_between(x, upper_band, lower_band, color=colour, alpha=0.3)

# Grey band from y=0 to y=2 (the full y-range set below) wherever the
# per-timepoint p-value falls under the corrected alpha.
sig_bottom = numpy.zeros(len(x))
sig_top = 2 * numpy.ones(len(x))
is_significant = p < alpha
ax.fill_between(x, sig_bottom, sig_top, where=is_significant,
                color='#babdb6', alpha=0.2)

# Axis ranges and labels.
# NOTE(review): the x label says ms, which holds only if the tracker sampled
# at 1000 Hz (one sample per millisecond) -- confirm the sampling rate.
ax.set_xlim([0, 2000])
ax.set_ylim([0, 2])
ax.set_xlabel('time (ms)')
ax.set_ylabel('proportional pupil size change')
ax.legend(loc='upper left')

fig.savefig('pupil_traces.png')