from __future__ import print_function
from PIL import Image
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
# Root directory handed to torchvision's MNIST dataset (the training split is
# created with download=True).
datafolder = '/home/yang/data/data/MNIST'

# Training split, converted to tensors and normalised with the constants
# below (presumably the MNIST mean / std -- confirm), in shuffled batches of 64.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        datafolder,
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=64,
    shuffle=True,
)

# Held-out split with the same preprocessing; it is not requested with
# download=True.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        datafolder,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]),
    ),
    batch_size=64,
    shuffle=True,
)

# saved[label] collects every training image carrying that label (ten buckets).
saved = [[] for _ in range(0, 10)]
for images, labels in train_loader:
    for image, label in zip(images, labels):
        saved[label].append(image)

# Module-level scratch state read by the functions defined below.
pair = [0, 1]
inp = Variable(torch.Tensor(28 * 28).cuda())  # buffer for one flattened image
I = Variable(torch.eye(28 * 28).cuda())       # identity weight matrix
relu = nn.ReLU()
s = []


def _best_threshold_accuracy(scores0, scores1):
    """Return the best accuracy a single score threshold achieves on two classes.

    Every cut point between distinct score values is tried, with both
    polarities ("below is class 0 / above is class 1" and the reverse).
    Equal scores are always consumed together, because no threshold can put
    two identical scores on different sides.

    Returns 0.0 when both lists are empty.
    """
    first = sorted(scores0)
    second = sorted(scores1)
    n0, n1 = len(first), len(second)
    total = n0 + n1
    if total == 0:
        return 0.0

    i = j = 0  # number of scores of each class lying below the current cut
    best = 0
    while i < n0 or j < n1:
        # Consume the smallest remaining score (this guarantees progress) ...
        if j >= n1 or (i < n0 and first[i] <= second[j]):
            cut = first[i]
            i += 1
        else:
            cut = second[j]
            j += 1
        # ... together with every other score equal to it, from both classes.
        while i < n0 and first[i] == cut:
            i += 1
        while j < n1 and second[j] == cut:
            j += 1
        # Correct counts for "below -> class 0" and for "below -> class 1".
        best = max(best, i + (n1 - j), j + (n0 - i))
    return float(best) / total


def getScore(W, pair, images=None):
    """Score how well sum(relu(W @ image)) separates two classes by a threshold.

    Each image of the two classes is flattened, multiplied by ``W``, passed
    through a ReLU and summed to one scalar. The return value is the best
    accuracy (in [0, 1]) obtainable by thresholding that scalar.

    Args:
        W: 2-D weight matrix (a Variable / tensor exposing ``.data``); its
            second dimension must equal the number of pixels per image.
        pair: two-element sequence holding the class keys to compare.
        images: optional mapping/sequence ``images[label] -> list of image
            tensors``. Defaults to the module-level ``saved`` buckets.

    Returns:
        float: best threshold accuracy, 0.0 if both classes have no images.

    Scores are kept in per-call lists, so repeated calls are independent of
    each other (no module-level accumulator is touched).
    """
    if images is None:
        images = saved

    # Per-call buffer with W's tensor type, hence on W's device; copy_ below
    # moves each image into it.
    vec = Variable(W.data.new(W.size(1)))

    per_class = []
    for label in (pair[0], pair[1]):
        class_scores = []
        for img in images[label]:
            # Flatten explicitly so the copy does not depend on shape rules.
            vec.data.copy_(img.view(-1))
            activation = F.relu(torch.mv(W, vec))
            # float(...) accepts both a Python number and a 0-dim tensor, which
            # is what .sum() yields on old and new PyTorch respectively.
            class_scores.append(float(activation.data.sum()))
        per_class.append(class_scores)

    return _best_threshold_accuracy(per_class[0], per_class[1])


# Now want to train this.
class simple(nn.Module):
    """Residual linear layer + ReLU, reduced to one scalar score per sample.

    ``forward`` computes ``sum(relu(l1(x) + x), dim=1) / scale``. The linear
    layer has no bias and its weights start at zero, so a fresh network
    outputs ``sum(relu(x)) / scale``.

    Args:
        n_features: length of each flattened input vector (default 28 * 28).
        scale: divisor applied to the summed activations (default 350).
    """

    def __init__(self, n_features=28 * 28, scale=350):
        super(simple, self).__init__()
        self.scale = scale
        self.l1 = nn.Linear(n_features, n_features, bias=False)
        self.l1.weight.data.fill_(0)
        self.Relu = nn.ReLU()

    def forward(self, inp):
        """Return a 1-D tensor with one score per row of ``inp``.

        ``inp`` is expected to be 2-D, (batch, n_features).
        """
        # Call the sub-module itself rather than .forward(), so hooks
        # registered on l1 are run.
        output = self.l1(inp)
        output = output + inp  # skip connection around the linear layer
        output = self.Relu(output)
        output = torch.sum(output, 1)
        output = output / self.scale
        return output.view(-1)


def train_net(pair, epochs=200):
    """Train a ``simple`` network to separate two classes; report error rates.

    Every epoch draws fresh random permutations of both classes and walks
    through balanced batches (first half class ``pair[0]`` with target 0,
    second half class ``pair[1]`` with target 1), using MSE loss and SGD.
    The first and the last epoch only evaluate; no weight update is applied
    in them. After each epoch the weight matrix is rescaled so that its
    largest singular value does not exceed 0.6.

    Args:
        pair: two-element sequence of class indices into ``saved``.
        epochs: total number of passes, including the two evaluation-only
            ones (default 200).

    Returns:
        list: ``[error_before_training, error_after_training]`` as fractions
        of the evaluated samples. If the untrained error exceeds 0.5 the
        function returns right after the first epoch and the second entry
        stays at its placeholder value 0.

    Raises:
        ValueError: if either class has too few images for a single batch.
    """
    eta = 0.001
    batch = 100
    halfb = batch // 2
    max_sv = 0.6  # cap on the largest singular value of the weight matrix

    class0 = saved[pair[0]]
    class1 = saved[pair[1]]
    # Only complete, balanced batches are used.
    n_batches = min(len(class0), len(class1)) // halfb
    if n_batches == 0:
        raise ValueError(
            "each class needs at least %d images to form a batch" % halfb)
    N = n_batches * batch  # samples evaluated per epoch

    inp = Variable(torch.Tensor(batch, 28 * 28).cuda())
    target = Variable(torch.Tensor(batch).cuda())
    net = simple().cuda()
    optimizer = optim.SGD(net.parameters(), lr=eta, weight_decay=0.01)
    target.data[:halfb].fill_(0)
    target.data[halfb:].fill_(1)
    criterion = nn.MSELoss()

    rlt = [0, 0]
    miss_rate = 0.0
    for i in range(0, epochs):
        r0 = torch.randperm(len(class0)).tolist()
        r1 = torch.randperm(len(class1)).tolist()
        # First and last epoch measure the error without changing weights.
        update = 0 < i < epochs - 1
        totE = 0.0
        miss = 0
        for j in range(0, n_batches):
            p = j * halfb
            for k in range(0, halfb):
                # Flatten explicitly so the copy does not rely on shape rules.
                inp.data[k].copy_(class0[r0[p + k]].view(-1))
                inp.data[halfb + k].copy_(class1[r1[p + k]].view(-1))
            net.zero_grad()
            output = net(inp)
            err = criterion(output, target)

            # Targets are 0 for the first half and 1 for the second, so a miss
            # is an output on the wrong side of 0.5 (exactly 0.5 is no miss).
            miss += int((output.data[:halfb] > 0.5).sum())
            miss += int((output.data[halfb:] < 0.5).sum())
            # float(...sum()) works whether the loss is 1-element or 0-dim.
            totE += float(err.data.sum())

            if update:
                err.backward()
                optimizer.step()

        # Explicit float: true division even without `from __future__ import
        # division` on Python 2.
        miss_rate = float(miss) / N
        if i == 0:
            rlt[0] = miss_rate
            if miss_rate > 0.5:
                return rlt

        U, sv, V = torch.svd(net.l1.weight.data)
        top = float(sv[0])  # torch.svd returns singular values in descending order
        if i % 10 == 0:
            print(i, "err=", totE / N * batch, top, miss_rate, miss, N)
        if top > max_sv:
            net.l1.weight.data.div_(top / max_sv)
    rlt[1] = miss_rate
    return rlt


# print("naive",getScore(I,[0,1]))
# Run the experiment for every ordered pair of distinct digits and collect
# [class_a, class_b, error_before_training, error_after_training].
full_ans = []
for i in range(0, 10):
    for j in range(0, 10):
        if i != j:
            print(i, j)
            rlt = train_net([i, j])
            full_ans.append([i, j, rlt[0], rlt[1]])
            print(full_ans)

# Average both error rates over the pairs for which training actually ran.
tot = 0
vanilla = 0.0
trained = 0.0
for entry in full_ans:
    # train_net gives up (leaving the trained slot at its 0 placeholder)
    # exactly when the untrained error exceeds 0.5, so that is the validity
    # test. Checking `entry[3] > 0` instead would also throw away a pair that
    # was trained down to zero errors.
    if entry[2] <= 0.5:
        tot += 1
        vanilla += entry[2]
        trained += entry[3]

if tot > 0:
    print(tot, vanilla / tot * 100, trained / tot * 100)
else:
    # Nothing to average; avoid dividing by zero.
    print(tot, "no valid class pairs")