require 'nn'
require 'groundTruth'
require 'model'
require 'cutorch'
require 'cunn'

require 'math'
local input_dim=1500

-- Student network being trained and teacher ("ground truth") network that
-- produces its targets; both are built by project helpers for `input_dim`.
local model=get_model(input_dim)
local GT=get_GT(input_dim)
model:cuda()
GT:cuda()

local cri=nn.MSECriterion():cuda()

local batch=300
-- Layer whose weights are analysed. Assumes both networks nest it as the
-- first child of their first child -- TODO confirm against model/groundTruth.
local l1=model:get(1):get(1)
local GT_l1=GT:get(1):get(1)

-- Restore the student's first-layer weights from a checkpoint. Copy into the
-- existing tensor instead of rebinding `l1.weight`, so the layer keeps its
-- CUDA tensor type and stays consistent with its gradWeight. The shape check
-- surfaces a mismatched checkpoint here rather than later in backward().
local saved_weight=torch.load("save.t7"):typeAs(l1.weight)
assert(saved_weight:isSameSizeAs(l1.weight),
   "save.t7: saved weight size does not match the model's first layer")
l1.weight:copy(saved_weight)

--- Spectral (induced 2-) norm of a matrix, i.e. its largest singular value.
-- The matrix is copied to a FloatTensor before factorising, so the caller's
-- tensor (which may be a CudaTensor) is left untouched.
-- Note: eigenvalue moduli are not a substitute here. For a non-symmetric
-- matrix they differ from singular values, and eigenvalues come back unsorted.
-- @param M 2-D tensor
-- @treturn number largest singular value of M
local function get_2_norm(M)
   local cpu_copy=M:clone():float()
   -- torch.svd returns U, S, V; only the singular values S are needed.
   local _, singular_values=torch.svd(cpu_copy)
   -- Take the max explicitly rather than relying on the ordering of S.
   return singular_values:max()
end
-- Report the restored first-layer weights: the value from get_2_norm next to
-- torch.norm, the Euclidean norm taken over all entries of the matrix.
local restored_w=l1.weight
print("2 norm", get_2_norm(restored_w), torch.norm(restored_w))

-- Zero the teacher's first layer; the distances tracked later are measured
-- against this all-zero matrix.
GT_l1.weight:fill(0)

--- Potential tracked during training.
-- Returns the sum over rows i = 1..input_dim of
--   ||GT_i + e_i|| - ||W_i + e_i||
-- where W = l1.weight, GT = GT_l1.weight, row i is `w[i]`, and e_i adds 1 to
-- entry i of that row. Each row is copied before the bump, so neither weight
-- matrix is modified.
-- @treturn number
local function potential()
   -- Euclidean norm of `vec` after adding 1 at position `pos` (on a copy).
   local function norm_with_unit_bump(vec, pos)
      local bumped=vec:clone()
      bumped[pos]=bumped[pos]+1
      return torch.norm(bumped)
   end

   local gt_w, cur_w=GT_l1.weight, l1.weight
   local acc=0
   for row=1,input_dim do
      local gt_term=norm_with_unit_bump(gt_w[row], row)
      local cur_term=norm_with_unit_bump(cur_w[row], row)
      acc=acc+gt_term-cur_term
   end
   return acc
end

require 'distributions'
local mean=torch.Tensor(input_dim):zero()
local cov=torch.eye(input_dim)
local times=10000000
-- Step size of the manual SGD update, which is applied to the first layer only.
local learning_rate=0.0001
-- Print a diagnostic line and checkpoint `rlt` every `log_every` iterations.
local log_every=100
local input=torch.Tensor(batch,input_dim):zero()
local cudaI=input:clone():cuda()
-- Re-cast so every parameter (including any loaded earlier) is a CUDA tensor.
model:cuda()
-- Reusable buffer for (l1.weight - GT_l1.weight). It is created after the
-- cast so it shares the weight's tensor type, avoiding an allocation per step.
local diff=l1.weight:clone()
local rlt={}
for j=1,times do
   -- Fresh batch with rows drawn from N(mean, cov) on the CPU, then moved to the GPU.
   distributions.mvn.rnd(input,mean,cov)
   --input=torch.Tensor(batch,input_dim):uniform():add(-0.5)
   --for k=1,batch do input[k]:div(torch.norm(input[k])) end
   cudaI:copy(input)

   -- Teacher targets, then a single forward/backward pass through the student.
   local real_ans=GT:forward(cudaI)
   local prediction=model:forward(cudaI)
   cri:forward(prediction,real_ans)
   model:zeroGradParameters()
   cri:backward(prediction,real_ans)
   model:backward(cudaI,cri.gradInput)

   -- Diagnostics are taken before this iteration's weight update. They are
   -- computed once and shared by the history record and the periodic print,
   -- since potential() in particular is expensive.
   diff:copy(l1.weight):add(-1,GT_l1.weight)
   local record={
      inner=torch.dot(l1.gradWeight,diff),
      dist=torch.norm(diff),
      potential=potential(),
      loss=cri.output,
   }
   -- Record first so the checkpoint written below includes iteration j.
   rlt[j]=record

   if j%log_every==0 then
      print('\t\t inner=',record.inner, 'l1 norm= ', torch.norm(l1.weight), 'dist=',
         record.dist,'potential=', record.potential, 'loss=',record.loss)
      -- NOTE(review): this rewrites the entire history every time, so the
      -- checkpoint cost grows with j.
      torch.save('rlt',rlt)
   end

   l1.weight:add(-learning_rate,l1.gradWeight)
end
