
import sys, math
import parse

def add2(a, b):
    """Return ``a + b`` for any two operands that support ``+``."""
    result = a + b
    return result

class confusionMatrix():
    """Contingency table between the labels of method 0 and method 1.

    Built from a cluster-info object ``ci`` that must provide:
      * ``ci.methods[0]`` / ``ci.methods[1]``: mappings whose ``values()``
        are the labels assigned by each method,
      * ``ci.samples``: an iterable of samples,
      * ``ci.lookup(k, s)``: the label method ``k`` gives sample ``s``.

    After construction the instance exposes:
      * ``m0_values`` / ``m1_values``: sets of distinct labels per method,
      * ``mat``: ``mat[v0][v1]`` is the number of samples labelled ``v0``
        by method 0 and ``v1`` by method 1,
      * ``row_sums``: per-``v0`` totals, in iteration order of ``m0_values``,
      * ``col_sums``: per-``v1`` totals, in iteration order of ``m1_values``.

    Raises KeyError if ``ci.lookup`` returns a label that is not among the
    corresponding ``ci.methods[k].values()``.
    """
    def __init__(self, ci):
        self.mat={}
        self.m0_values=set(ci.methods[0].values())
        self.m1_values=set(ci.methods[1].values())
        # construct matrix and initialize to 0
        for v0 in self.m0_values:
            self.mat[v0]={}
            for v1 in self.m1_values:
                self.mat[v0][v1]=0

        # fill in matrix by considering each sample and tallying the right cell.
        for s in ci.samples:
            v0=ci.lookup(0,s)
            v1=ci.lookup(1,s)
            self.mat[v0][v1]+=1

        # compute row sums and col sums.  sum() is used instead of reduce():
        # it needs no helper, exists as a builtin on every Python version,
        # and yields 0 (rather than raising) when a label set is empty.
        self.row_sums=[sum(self.mat[i][j] for j in self.m1_values) for i in self.m0_values]
        self.col_sums=[sum(self.mat[i][j] for i in self.m0_values) for j in self.m1_values]

def Ha(ci, cm):
    """Return the base-2 entropy of the distribution given by ``cm.row_sums``.

    Each row sum is divided by ``len(ci.samples)`` to obtain a proportion
    ``p``, and the result is ``-sum(p * log2(p))``.

    A row sum of zero contributes nothing (the ``0 * log 0 = 0`` convention).
    This is checked explicitly because ``math.log(0, 2)`` raises ValueError,
    which an ``except OverflowError`` clause would not intercept.
    """
    total=0.0
    n=float(len(ci.samples))
    for rs in cm.row_sums:
        if rs<=0:
            # empty row: skip instead of evaluating log(0)
            continue
        p=rs/n
        total+=p*math.log(p, 2.0)
    return -total

def Hb(ci, cm):
    """Return the base-2 entropy of the distribution given by ``cm.col_sums``.

    Each column sum is divided by ``len(ci.samples)`` to obtain a proportion
    ``p``, and the result is ``-sum(p * log2(p))``.

    A column sum of zero contributes nothing (the ``0 * log 0 = 0``
    convention).  This is checked explicitly because ``math.log(0, 2)``
    raises ValueError, which an ``except OverflowError`` clause would not
    intercept.
    """
    total=0.0
    n=float(len(ci.samples))
    for cs in cm.col_sums:
        if cs<=0:
            # empty column: skip instead of evaluating log(0)
            continue
        p=cs/n
        total+=p*math.log(p, 2.0)
    return -total

def Hab(ci, cm):
    """Return the base-2 entropy over all cells of ``cm.mat``.

    For every ``i`` in ``cm.m0_values`` and ``j`` in ``cm.m1_values`` the cell
    count ``cm.mat[i][j]`` is divided by ``len(ci.samples)`` to obtain a
    proportion ``p``; the result is ``-sum(p * log2(p))`` over all cells.

    Empty cells contribute nothing (the ``0 * log 0 = 0`` convention).  They
    are skipped explicitly because ``math.log(0, 2)`` raises ValueError,
    which an ``except OverflowError`` clause would not intercept.
    """
    total=0.0
    n=float(len(ci.samples))
    for i in cm.m0_values:
        row=cm.mat[i]
        for j in cm.m1_values:
            count=row[j]
            if count<=0:
                # empty cell: skip instead of evaluating log(0)
                continue
            p=count/n
            total+=-1.0*p*math.log(p, 2.0)
    return total

def nmi(ci):
    """Score the agreement between the two labelings held by ``ci``.

    Builds a ``confusionMatrix`` from ``ci`` and returns
    ``1 - (Hab - Hb) / Ha``, where ``Ha``, ``Hb`` and ``Hab`` are the values
    returned by the functions of the same names.  Judging by the function
    name this is intended as a normalized mutual information.

    Raises ZeroDivisionError when ``Ha`` evaluates to zero.
    """
    table=confusionMatrix(ci)
    entropy_a=Ha(ci, table)
    entropy_b=Hb(ci, table)
    entropy_joint=Hab(ci, table)
    # portion of the joint entropy not accounted for by Hb
    remainder=entropy_joint-entropy_b
    return 1-remainder/entropy_a
    
if __name__=='__main__':
    # Command line: <cluster_file> <bootstraps>
    #   cluster_file - path handed to parse.ClusterInfo.parseFile
    #   bootstraps   - number of permutations used for the empirical test
    if len(sys.argv)<3:
        sys.stderr.write("usage: %s <cluster_file> <bootstraps>\n" % sys.argv[0])
        sys.exit(2)
    clfile=sys.argv[1]
    bootstraps=int(sys.argv[2])
    if bootstraps<1:
        # the final ratio divides by bootstraps; reject before doing any work
        sys.stderr.write("bootstraps must be a positive integer\n")
        sys.exit(2)
    hasheaders=False

    ci=parse.ClusterInfo()
    ci.parseFile(clfile, hasheaders)
    v=nmi(ci)

    # Count how many permuted data sets score strictly higher than the
    # observed one.  ci.permute() is presumed to return a new, shuffled
    # cluster-info object -- confirm against the parse module.
    cnt=0
    for _ in range(bootstraps):
        cip=ci.permute()
        vp=nmi(cip)
        if vp>v: cnt+=1
    # Output: observed score, then the fraction of permutations that beat it.
    # Written via sys.stdout so the line is the same on Python 2 and Python 3.
    sys.stdout.write("%s %s\n" % (v, float(cnt)/bootstraps))
    
