require 'nn'
require 'model2'
require 'groundTruth'
require 'cutorch'
require 'cunn'

require 'math'

-- Experiment setup: build the trainable model and the ground-truth network,
-- load the reference weights, sample the train/test inputs and open the
-- result history. get_model2 / get_GT are globals not defined in this file
-- (presumably provided by the 'model2' and 'groundTruth' modules).
local input_dim=100
local model=get_model2(input_dim)
local gt=get_GT(input_dim)

local cri=nn.MSECriterion():cuda()

local batch=200
-- l1 is the first module of the model; gt_l1 is the first module nested
-- inside the first module of the ground-truth network.
local l1=model:get(1)
local gt_l1=gt:get(1):get(1)
print(model)
print('linear layer=',l1)

require 'distributions'
local mean=torch.Tensor(input_dim):zero()
local cov=torch.eye(input_dim)
-- Replace the ground-truth layer's weight with the tensor stored on disk.
-- NOTE(review): the tensor is swapped in as-is, so its type and size must
-- already match what the ground-truth network expects - confirm.
gt_l1.weight=torch.load("ground_truth")
print('gt_l1 norm=',torch.norm(gt_l1.weight))
-- Shrink the model's initial first-layer weights in place.
l1.weight:div(1.2)
print('l1 norm=',torch.norm(l1.weight))

local train_size=100000
local test_size=10000
local train_book=torch.Tensor(train_size,input_dim):zero()
local test_book=torch.Tensor(test_size,input_dim):zero()
-- Fill both books with samples drawn using a zero mean and identity covariance.
distributions.mvn.rnd(train_book,mean,cov)
distributions.mvn.rnd(test_book,mean,cov)
-- Tensor:cuda() returns a converted copy instead of converting in place, so
-- the result has to be assigned back for the books to actually live on the GPU.
train_book=train_book:cuda()
test_book=test_book:cuda()

-- Reusable mini-batch buffer; rows are overwritten for every batch.
local input=torch.CudaTensor(batch,input_dim):zero()
-- Holds the ground-truth output for the current batch. It is always assigned
-- from gt:forward(input) before it is read, so no buffer is allocated here.
local real_ans

-- History of all runs. When no history file exists yet, start from an empty
-- one so the very first run can create it instead of failing in torch.load.
local grand
do
   local history_file=io.open('rlt_2','r')
   if history_file then
      history_file:close()
      grand=torch.load('rlt_2')
   else
      grand={}
   end
end
-- Per-epoch records of this run, appended to the history.
local rlt={}
grand[#grand+1]=rlt
--- Accumulated difference between the ground-truth and the trained
-- first-layer weights, measured row by row.
-- For every row index r in 1..input_dim, row r of each weight matrix is
-- copied, 1 is added to its r-th entry, and the torch.norm of the result is
-- taken; the trained layer's value is subtracted from the ground truth's.
-- Reads gt_l1.weight and l1.weight at call time and does not modify them.
-- @treturn number sum over all rows of (ground-truth norm - model norm)
local function potential()
   -- Norm of row `row_idx` of `weights` after adding 1 to its `row_idx`-th
   -- entry. Works on a clone so the layer weights stay untouched.
   local function shifted_row_norm(weights,row_idx)
      local row=weights[row_idx]:clone()
      row[row_idx]=row[row_idx]+1
      return torch.norm(row)
   end

   local total=0
   for row_idx=1,input_dim do
      total=total+shifted_row_norm(gt_l1.weight,row_idx)-shifted_row_norm(l1.weight,row_idx)
   end
   return total
end

-- Training driver: 200 epochs of mini-batch gradient descent on the first
-- layer's weight only, regressing the model output onto the ground-truth
-- network's output, followed by an evaluation pass over the test book.
-- Per-epoch statistics go into rlt[epoch] and the whole history is written
-- back to 'rlt_2' after every epoch.
local stepSize=0.001
local eyeM=torch.eye(input_dim):cuda()
print("Eyr",torch.norm(eyeM))
for epoch=1,200 do
   -- Fresh random visiting order over the training rows for this epoch.
   local order=torch.randperm(train_size)
   local stats={}
   rlt[epoch]=stats

   local train_loss_sum=0
   for batch_no=1,train_size/batch do
      -- Gather the next `batch` shuffled rows into the reusable input buffer.
      local offset=(batch_no-1)*batch
      for row=1,batch do
         input[row]:copy(train_book[order[offset+row]])
      end

      -- Targets come from the ground-truth network on the same inputs.
      real_ans=gt:forward(input)

      local prediction=model:forward(input)
      local loss=cri:forward(prediction,real_ans)
      model:zeroGradParameters()
      local loss_grad=cri:backward(prediction,real_ans)
      model:backward(input,loss_grad)

      -- Diagnostics are sampled once per epoch, on the first batch, using the
      -- gradient just computed and the weights before this batch's update.
      if batch_no==1 then
         stats.potential=potential()
         stats.inner=torch.dot(l1.gradWeight,l1.weight-gt_l1.weight)
      end

      -- Plain gradient step with a fixed step size; only l1.weight is updated.
      l1.weight:add(-stepSize,l1.gradWeight)
      train_loss_sum=train_loss_sum+loss
   end

   -- Mean loss per batch and the two weight-distance diagnostics.
   local mean_train_loss=train_loss_sum/train_size*batch
   local dist_to_gt=torch.norm(l1.weight-gt_l1.weight-eyeM)
   local weight_norm=torch.norm(l1.weight-eyeM)
   print("tot train loss=",mean_train_loss,dist_to_gt,
         weight_norm)
   stats.train_loss=mean_train_loss
   stats.dist_2_gt=dist_to_gt
   stats.norm=weight_norm

   -- Evaluation: walk the test book in order, forward passes only.
   local test_loss_sum=0
   for batch_no=1,test_size/batch do
      local offset=(batch_no-1)*batch
      for row=1,batch do
         input[row]:copy(test_book[offset+row])
      end
      real_ans=gt:forward(input)
      test_loss_sum=test_loss_sum+cri:forward(model:forward(input),real_ans)
   end
   local mean_test_loss=test_loss_sum/test_size*batch
   print("tot loss=\t\t\t\t\t\t\t\t\t\t",mean_test_loss)
   stats.test_loss=mean_test_loss

   -- Persist after every epoch so partial runs are not lost.
   torch.save('rlt_2',grand)
end
