import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import imageio


# Spacing constant. Note: func2 and run_b each take their own DX parameter
# (func2 defaults to the literal 0.1), so neither reads this module-level value.
DX = 0.1
# Target angle used by func3 at every interior point; pi means the two
# neighbouring segments point in opposite directions (a straight chain).
angle = np.pi
# func4 and func5 penalize any point whose norm exceeds R - r; presumably R is
# a containing radius and r a particle radius -- confirm with the data source.
R = 2
# func5 additionally penalizes pairs (rows 0-3 vs rows 4-7) closer than 2 * r.
r = 0.1

def func1(data, initial_shape):
    """Measure how well pairwise distances of ``initial_shape`` are preserved.

    For every unordered pair of rows (a, b) the deviation
    ``|initial_shape[a] - initial_shape[b]| - |data[a] - data[b]|`` is taken.

    Returns:
        (mean absolute deviation, mean squared deviation) over all pairs.
        Raises ZeroDivisionError when ``data`` has fewer than two rows.
    """
    n_points = len(data)
    pairs = [(a, b) for a in range(n_points) for b in range(a + 1, n_points)]
    deviations = [
        np.linalg.norm(initial_shape[a, :] - initial_shape[b, :])
        - np.linalg.norm(data[a, :] - data[b, :])
        for a, b in pairs
    ]
    n_pairs = len(deviations)
    abs_total = sum(np.abs(dev) for dev in deviations)
    sq_total = sum(dev ** 2 for dev in deviations)
    return abs_total / n_pairs, sq_total / n_pairs

def func2(data, initial_shape, DX = 0.1):
    """Measure deviation of consecutive-point spacing from the target ``DX``.

    ``initial_shape`` is accepted only so every metric shares one signature;
    it is not used.

    Returns:
        (mean |DX - gap|, mean (DX - gap)**2) over the len(data)-1 gaps
        between consecutive rows. Raises ZeroDivisionError for fewer than
        two rows.
    """
    gaps = [np.linalg.norm(data[k, :] - data[k + 1, :]) for k in range(len(data) - 1)]
    errors = [DX - gap for gap in gaps]
    n_gaps = len(errors)
    abs_total = sum(np.abs(err) for err in errors)
    sq_total = sum(err ** 2 for err in errors)
    return abs_total / n_gaps, sq_total / n_gaps

def func3(data, initial_shape):
    """Measure deviation of interior angles from the module-level ``angle``.

    At every interior row ``data[i + 1]`` the angle between the vectors to its
    two neighbours is computed with ``arccos`` of the normalized dot product.
    ``initial_shape`` is accepted only for a uniform metric signature and is
    not used.

    Returns:
        (mean |angle - theta|, mean (angle - theta)**2) over the len(data)-2
        interior rows. Raises ZeroDivisionError when ``data`` has fewer than
        three rows. A zero-length segment still yields NaN (0/0).
    """
    ret = 0
    ret2 = 0
    cnt = 0
    for i in range(len(data) - 2):
        t1 = data[i, :] - data[i + 1, :]
        t2 = data[i + 2, :] - data[i + 1, :]
        cos_theta = np.dot(t1, t2) / (np.linalg.norm(t1) * np.linalg.norm(t2))
        # For (nearly) collinear points -- exactly the target when angle == pi --
        # rounding can push the cosine slightly outside [-1, 1]; arccos would
        # then return NaN and poison both averages, so clamp it first.
        test = np.arccos(np.clip(cos_theta, -1.0, 1.0))
        diff = angle - test
        ret += np.abs(diff)
        ret2 += diff ** 2
        cnt += 1
    return ret / cnt, ret2 / cnt


def func4(data, initial_shape):
    """Measure how far points stick out beyond the radius ``R - r``.

    Each row's norm is compared with ``R - r``; only the positive excess
    counts (points inside contribute 0). ``initial_shape`` is unused.

    Returns:
        (mean excess, mean squared excess) over all rows of ``data``.
        Raises ZeroDivisionError for empty ``data``.
    """
    limit = R - r
    overshoots = (np.linalg.norm(data[k, :]) - limit for k in range(len(data)))
    excess = [0 if over < 0 else over for over in overshoots]
    n_rows = len(excess)
    abs_total = sum(np.abs(e) for e in excess)
    sq_total = sum(e ** 2 for e in excess)
    return abs_total / n_rows, sq_total / n_rows
    
def func5(data, initial_shape):
    """Pooled penalty for inter-group overlap and boundary overshoot.

    Two kinds of terms are averaged together:
      * for every pair (i, j) with i in rows 0-3 and j in rows 4-7, the amount
        by which their distance falls short of ``2 * r`` (0 if not closer);
      * for every row, the amount by which its norm exceeds ``R - r``
        (0 if inside).
    ``data`` must therefore have at least 8 rows. ``initial_shape`` is unused.

    Returns:
        (mean |penalty|, mean penalty**2) over all 16 pair terms plus one
        term per row.
    """
    ret = 0
    ret2 = 0
    cnt = 0
    min_dist = r * 2
    for i in range(4):
        for j in range(4, 8):
            temp = np.linalg.norm(data[i, :] - data[j, :]) - min_dist
            if temp > 0:
                # Not overlapping: no penalty.
                temp = 0
            ret += np.abs(temp)
            ret2 += temp ** 2
            cnt += 1
    limit = R - r
    for i in range(len(data)):
        temp = np.linalg.norm(data[i, :]) - limit
        if temp < 0:
            # Inside the boundary: no penalty.
            temp = 0
        ret += np.abs(temp)
        ret2 += temp ** 2
        cnt += 1
    # (A debug print(cnt) used to fire here on every call; removed.)
    return ret / cnt, ret2 / cnt



# Constraint metrics indexed by constraint id: the run* helpers call
# functions[c] whenever constraint_list[c] == 1. Every entry returns a
# (mean absolute error, mean squared error) pair.
functions = [func1, func2, func3, func4, func5]

def run(proj_list, constraint_list, initial_shape, name):
    """Evaluate each enabled constraint metric on every projection and log it.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied to each whole element of ``proj_list``. One line ``i<TAB>a1<TAB>a2``
    per projection is written to ``results/<name>_<c>.txt``, followed by an
    ``avg`` line holding the mean of each column over all projections.

    Args:
        proj_list: sequence of point arrays, each passed whole to the metric.
        constraint_list: per-metric enable flags (1 = evaluate).
        initial_shape: reference shape forwarded to every metric.
        name: prefix of the output file names.
    """
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        # `with` guarantees the file is closed even if a metric raises.
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1, a2 = f(proj, initial_shape)
                avg1 += a1
                avg2 += a2
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')


def run_a(proj_list, constraint_list, initial_shape, name):
    """Evaluate enabled metrics on rows 6:10 of each projection and log them.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied to ``proj[6:10]`` of each projection. One line ``i<TAB>a1<TAB>a2``
    per projection is written to ``results/<name>_<c>.txt``, followed by an
    ``avg`` line holding the mean of each column over all projections.
    """
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1, a2 = f(proj[6:10], initial_shape)
                avg1 += a1
                avg2 += a2
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')

def run_b(proj_list, constraint_list, initial_shape, name, DX):
    """Evaluate enabled metrics on rows 0:6 of each projection and log them.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied to ``proj[0:6]`` of each projection. The spacing metric (index 1)
    receives the caller's ``DX`` as its target spacing. One line
    ``i<TAB>a1<TAB>a2`` per projection is written to
    ``results/<name>_<c>.txt``, followed by an ``avg`` line holding the mean
    of each column over all projections.
    """
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                if c == 1:
                    # Use the DX argument; it was previously ignored in favour
                    # of a hard-coded 0.2.
                    a1, a2 = f(proj[0:6], initial_shape, DX=DX)
                else:
                    a1, a2 = f(proj[0:6], initial_shape)
                avg1 += a1
                avg2 += a2
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')


def run_c(proj_list, constraint_list, initial_shape, name):
    """Evaluate enabled metrics on rows 7:11 and 17:21 of each projection.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied separately to each row group of every projection. Each
    per-projection line of ``results/<name>_<c>.txt`` holds the *sum* of the
    group results. The final ``avg`` line holds the per-group mean, averaged
    over all projections.
    """
    groups = (slice(7, 11), slice(17, 21))
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1 = 0
                a2 = 0
                for rows in groups:
                    g1, g2 = f(proj[rows, :], initial_shape)
                    a1 += g1
                    a2 += g2
                avg1 += a1 / len(groups)
                avg2 += a2 / len(groups)
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')

def run_d(proj_list, constraint_list, initial_shape, name):
    """Evaluate enabled metrics on rows 0:8, 10:18 and 20:28 of each projection.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied separately to each row group of every projection. Each
    per-projection line of ``results/<name>_<c>.txt`` holds the *sum* of the
    group results. The final ``avg`` line holds the per-group mean, averaged
    over all projections.
    """
    groups = (slice(0, 8), slice(10, 18), slice(20, 28))
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1 = 0
                a2 = 0
                for rows in groups:
                    g1, g2 = f(proj[rows, :], initial_shape)
                    a1 += g1
                    a2 += g2
                avg1 += a1 / len(groups)
                avg2 += a2 / len(groups)
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')

def run_e(proj_list, constraint_list, initial_shape, name):
    """Evaluate enabled metrics on rows 0:4 and 4:8 of each projection.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied separately to each row group of every projection. Each
    per-projection line of ``results/<name>_<c>.txt`` holds the *sum* of the
    group results. The final ``avg`` line holds the per-group mean, averaged
    over all projections.
    """
    groups = (slice(0, 4), slice(4, 8))
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1 = 0
                a2 = 0
                for rows in groups:
                    g1, g2 = f(proj[rows, :], initial_shape)
                    a1 += g1
                    a2 += g2
                avg1 += a1 / len(groups)
                avg2 += a2 / len(groups)
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')

def run_f(proj_list, constraint_list, initial_shape, name):
    """Evaluate enabled metrics on four 4-row groups (rows 0:16) per projection.

    For every index ``c`` with ``constraint_list[c] == 1``, ``functions[c]`` is
    applied separately to rows 0:4, 4:8, 8:12 and 12:16 of every projection.
    Each per-projection line of ``results/<name>_<c>.txt`` holds the *sum* of
    the group results. The final ``avg`` line holds the per-group mean,
    averaged over all projections.
    """
    groups = (slice(0, 4), slice(4, 8), slice(8, 12), slice(12, 16))
    # Average over the real number of projections (previously hard-coded to
    # 100); max(..., 1) keeps an empty list writing 0.0 instead of raising.
    n_proj = max(len(proj_list), 1)
    for c, flag in enumerate(constraint_list):
        if flag != 1:
            continue
        with open("results/" + name + '_' + str(c) + ".txt", "w") as file1:
            f = functions[c]
            avg1 = 0
            avg2 = 0
            for i, proj in enumerate(proj_list):
                a1 = 0
                a2 = 0
                for rows in groups:
                    g1, g2 = f(proj[rows, :], initial_shape)
                    a1 += g1
                    a2 += g2
                # Both columns are averaged over the 4 groups (a1 was
                # previously divided by 24 by mistake).
                avg1 += a1 / len(groups)
                avg2 += a2 / len(groups)
                file1.write(str(i) + '\t' + str(a1) + '\t' + str(a2) + '\n')
            file1.write('avg' + '\t' + str(avg1 / n_proj) + '\t' + str(avg2 / n_proj) + '\n')


def read_gt(t, T, root='D:/ysq/codes/complex/bin/win/physics_dataset/Release/data/'):
    """Load T+1 ground-truth frames for scene id ``t``.

    The label file ``<root><prefix>/data_<prefix>_0.txt`` holds one float per
    line. It is consumed row-major, ``num * 2`` values per frame.

    Args:
        t: scene id; one of 1, 2, 4, 6, 7, 8, 9.
        T: index of the last frame; frames 0..T are read.
        root: dataset directory (must end with a path separator).

    Returns:
        list of T+1 numpy arrays of shape (num, 2).

    Raises:
        ValueError: if ``t`` is not a known scene id.
    """
    # scene id -> (directory / file prefix, points per frame)
    scenes = {
        1: ('gt1_', 4),
        2: ('gt2_', 16),
        4: ('gt3_', 8),
        6: ('gt6_', 10),
        7: ('gt4_', 20),
        8: ('gt8_', 28),
        9: ('gt9_', 32),
    }
    if t not in scenes:
        # Previously this printed a marker and then crashed with UnboundLocalError.
        raise ValueError("unknown ground-truth id %r; expected one of %s"
                         % (t, sorted(scenes)))
    path, num = scenes[t]
    label_path = root + path + '/data_' + path + '_0.txt'
    ret = []
    with open(label_path, "r") as data_f:
        for _ in range(T + 1):
            data = np.ones([num, 2])
            for j in range(num):
                for k in range(2):
                    data[j, k] = float(data_f.readline())
            ret.append(data)
    return ret

def cal_mse(data, t, name, T=50):
    """Write the per-frame MSE between ``data`` and the ground truth of scene ``t``.

    Frames 1..T are compared (frame 0 is only printed). Each line of
    ``results/<name>_mse.txt`` is ``i<TAB>mse``, where mse is the mean of the
    element-wise squared difference for frame i.

    Args:
        data: indexable sequence of per-frame arrays, at least T+1 long.
        t: scene id forwarded to ``read_gt``.
        name: prefix of the output file name.
        T: index of the last compared frame (default 50, as before).
    """
    label = read_gt(t, T)
    # `with` closes the file even if a frame is missing or mis-shaped.
    with open("results/" + name + '_' + "mse.txt", "w") as file1:
        print(data[0])
        print(label[0])
        for i in range(1, T + 1):
            # A single np.average already reduces over all elements; the
            # former nested call was redundant.
            lo = np.average((data[i] - label[i]) ** 2)
            print(lo)
            file1.write(str(i) + '\t' + str(lo) + '\n')
