import datetime
import math
#import seaborn as sns
import simpy
import random
import numpy as np
import sys
import matplotlib
matplotlib.use('TkAgg')  # Adjust based on your OS and environment
import matplotlib.pyplot as plt
import pandas as pd
from real_dataset.parse_datasets import load_and_prepare_data


# Response times collected during the current run. run_simulation() rebinds
# these names to fresh lists at the start of every run.
response_times = []
response_pshort_times = []
response_plong_times = []
# Waiting jobs. job_generator() appends (arrival_time, size) tuples to
# short_queue and (arrival_time, predicted_size, age, actual_size) tuples to
# long_queue.
short_queue = []
long_queue = []
#response_plong_times_details = []

# start_time: simulation time at which serve_job() last started a long job.
# job_processes: long-job simpy processes keyed by that start time.
start_time = None
job_processes = {}
# Counters maintained by job_generator(): n_cheap_p is incremented for every
# arrival, n_expensive_p for every job routed to the long queue.
n_cheap_p = 0
n_expensive_p = 0
#SIMULATION_TIME = 1000000 #TODO
# Passed to env.run(until=...) in run_simulation().
SIMULATION_TIME = 100000
# NOTE: the threshold T is not bound at module level in the visible source;
# run_simulation() assigns it (global T) before any job is generated.
#T= float('inf') #to test FCFS
#T= 0 #to test SPRPT
#T = 4
# Long job currently in service as a long_queue-style tuple, or None.
current_job = None
# Truthy -> log_event() prints one row per event.
LOG_EVENT_PRINT = 0
# Truthy -> the test_cost_vs_* functions also save matplotlib figures.
PLOT_GRAPHS = 0

# Predictor selection, read by predict_service_time()
# ('perfectP', 'expP' and 'uniP' are the values it recognises).
#predictor = 'expP'
#predictor = 'perfectP'
#predictor = 'uniP'

# Job-size distribution selection, read by job_size_distribution()
# ('exponential', 'weibull' or 'real').
#dist = 'weibull'
#dist = 'exponential'

#Real dataset
dist = 'real'
predictor = 'real'

# Row cursor into `real_data`, advanced by job_size_distribution().
# NOTE(review): `real_data` itself is not defined in the visible part of this
# file - presumably loaded via load_and_prepare_data(); confirm.
real_data_index = 0

# Relative error bounds handed to predict_service_time() for the cheap and
# the expensive prediction respectively (only used by the 'uniP' predictor).
cheap_alpha = 0.8
expensive_alpha = 0.2


#TEST COST ACC
# When truthy, job_generator() classifies jobs with
# is_service_time_less_than_T(job_size, T, cheap_acc) instead of comparing the
# cheap prediction against T.
use_seperate_cheapp = 0
cheap_acc = 0.8

# Alpha -> cost lookup table.
# NOTE(review): not referenced anywhere in the visible source.
alpha_cost_mapping = {
    0: 2,    # When alpha is 0 (perfect predictor), cost is 2
    0.1: 1.85,
    0.3: 1.55,
    0.5: 1.25,
    0.8: 0.75,
    1: 0.5
}


class EndOfDataException(Exception):
    """Raised by job_size_distribution() when the real dataset has no rows left."""

def log_event(time, event, job_id="--", size="--", predicted_size="--", notes="--", queue_content="--"):
    """Print one row of the event table when LOG_EVENT_PRINT is truthy.

    Numeric cells are rendered with two decimals (the time with five); the
    placeholder "--" and the note "DONE" are printed verbatim.
    """
    if not LOG_EVENT_PRINT:
        return

    def two_decimals(value):
        # "--" marks an empty cell and must not go through float formatting.
        return "--" if value == "--" else f"{value:.2f}"

    job_cell = two_decimals(job_id)
    size_cell = two_decimals(size)
    predicted_cell = two_decimals(predicted_size)
    notes_cell = "DONE" if notes == "DONE" else two_decimals(notes)

    columns = [
        f"{f'{time:.5f}':<10}",
        f"{event:<20}",
        f"{job_cell:<10}",
        f"{size_cell:<10}",
        f"{predicted_cell:<15}",
        f"{notes_cell:<20}",
        f"{queue_content}",
    ]
    print("| ".join(columns))



def short_job_size_distribution(): #TODO: change
    """Rejection-sample a job size strictly below the global threshold T.

    Draws from job_size_distribution() until a size smaller than T comes up
    and returns that size (a single number, not a tuple).

    Note: the loop only ends when such a sample exists; with a threshold that
    no sample can undercut it keeps drawing (for the 'real' distribution until
    job_size_distribution() raises EndOfDataException).
    """
    while True:
        # job_size_distribution() returns (size, predicted_size). Comparing
        # the whole tuple against T raised TypeError, so unpack the size first.
        sample, _ = job_size_distribution()
        if sample < T:
            return sample

#def job_size_distribution():
#    return random.expovariate(1)


#def job_size_distribution():
#    U = random.random()
#    return (-math.log(1 - U))**2 / 2

def job_size_distribution():
    """Draw the next job as a (size, predicted_size) pair according to `dist`.

    * 'exponential': size from random.expovariate(1), prediction placeholder 0.
    * 'weibull':     size (-ln(1 - U))**2 / 2 for a uniform U, placeholder 0.
    * 'real':        next row of `real_data` (columns 'normalized_runtime' and
                     'normalized_predicted_runtime'); advances real_data_index.

    Raises:
        EndOfDataException: 'real' is selected and every row has been used.
        ValueError: `dist` is none of the values above.
    """
    global real_data_index

    if dist == 'exponential':
        size = random.expovariate(1)
        return size, 0

    if dist == 'weibull':
        uniform_draw = random.random()
        return (-math.log(1 - uniform_draw))**2 / 2, 0

    if dist != 'real':
        raise ValueError("Unknown distribution type specified.")

    # Real dataset: hand out rows in order until the data runs out.
    if real_data_index >= len(real_data): #TODO
        raise EndOfDataException("End of data reached")

    row = real_data_index
    job_size = real_data.loc[row, 'normalized_runtime']
    predicted_job_size = real_data.loc[row, 'normalized_predicted_runtime']
    real_data_index = row + 1
    return job_size, predicted_job_size


def predict_service_time(job, uni_alpha):
    """Return a predicted service time for a job of true size `job`.

    The module-level `predictor` selects the model:

    * 'perfectP': the true size itself.
    * 'expP':     an exponential sample whose mean is the true size.
    * 'uniP':     a uniform sample from [(1 - uni_alpha) * job, (1 + uni_alpha) * job].

    Args:
        job: true job size.
        uni_alpha: relative half-width of the 'uniP' interval; ignored by the
            other predictors.

    Raises:
        ValueError: `predictor` is none of the values above.
    """
    if predictor == 'perfectP':
        return job

    if predictor == 'expP':
        if job == 0:
            # An exponential with mean 0 is the constant 0; the rate 1 / job
            # would otherwise raise ZeroDivisionError.
            return 0.0
        return random.expovariate(1 / job)

    if predictor == 'uniP':
        lower_bound = (1 - uni_alpha) * job
        upper_bound = (1 + uni_alpha) * job
        return random.uniform(lower_bound, upper_bound)

    # The old message blamed the *distribution*, which sent debugging the wrong way.
    raise ValueError(f"Unknown predictor type specified: {predictor!r}")


# def is_service_time_less_than_T(true_service_time, T, accuracy_prob, max_deviation):
#     """
#     Predict if the service time is less than T.
#
#     Parameters:
#     - true_service_time: The actual service time.
#     - T: The threshold service time.
#     - accuracy_prob: The desired probability of accurate prediction.
#     - max_deviation: The maximum allowed deviation for the prediction to be considered accurate.
#
#     Returns:
#     - prediction: True if predicted service time is less than T, False otherwise.
#     - is_accurate: True if the prediction is within the desired accuracy range, False otherwise.
#     - deviation: The deviation of the prediction from the true service time.
#     """
#     # Determine if the prediction should be accurate
#     if random.random() <= accuracy_prob:
#         predicted_service_time = true_service_time
#     else:
#         # Introduce a random deviation within the allowed maximum deviation
#         deviation = random.uniform(-max_deviation, max_deviation)
#         predicted_service_time = true_service_time + deviation
#
#     # Calculate the deviation from the true service time
#     deviation = abs(predicted_service_time - true_service_time)
#
#     # Determine if the prediction is accurate
#     is_accurate = deviation <= max_deviation
#
#     # Make the prediction
#     prediction = predicted_service_time < T
#
#     return prediction, is_accurate, deviation

def is_service_time_less_than_T(true_service_time, T, accuracy_prob):
    """Noisy binary classifier for "is this job shorter than T?".

    One uniform draw decides whether the answer is truthful: when the draw is
    at most `accuracy_prob` the real outcome of `true_service_time < T` is
    returned, otherwise its negation.

    Returns:
        (prediction, is_accurate): the possibly flipped answer and whether it
        is the truthful one.
    """
    truly_short = true_service_time < T

    # Guard clause for the "wrong on purpose" case.
    if not (random.random() <= accuracy_prob):
        return (not truly_short), False

    return truly_short, True


############### Different types of predictors  #################

#def predict_service_time(job): #TODO: perfect predictor
#    return job

#def predict_service_time(job): #TODO uniP predictor
#    lower_bound = (1 - alpha) * job
#    upper_bound = (1 + alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_service_time(z): #TODO: exponential predictor
#     return random.expovariate(1/z)

#def predict_uniP_cheap(job): #TODO cheap uniP predictor
#    global cheap_alpha
#    lower_bound = (1 - cheap_alpha) * job
#    upper_bound = (1 + cheap_alpha) * job
#    return random.uniform(lower_bound, upper_bound)

#def predict_uniP_expensive(job):  # TODO expensive uniP predictor
#    global expensive_alpha
#    lower_bound = (1 - expensive_alpha) * job
#    upper_bound = (1 + expensive_alpha) * job
#    return random.uniform(lower_bound, upper_bound)
#####################################################

def remaining_time(job):
    """Predicted work still outstanding for a long-queue entry.

    `job` is indexed like the tuples appended to long_queue: index 1 holds the
    predicted size and index 2 the age (service received so far).
    """
    predicted_size, attained_age = job[1], job[2]
    return predicted_size - attained_age

def check_schedule(env, server):
    """Start a serve_job() process for whichever queue should be served next.

    Short jobs take precedence: if short_queue is non-empty a non-preemptive
    serve_job process is spawned, otherwise a preemptive one is spawned when
    long_queue is non-empty. Nothing happens when both queues are empty.

    Args:
        env: the simpy environment.
        server: the simpy resource handed on to serve_job().
    """
    if LOG_EVENT_PRINT:
        # Rendering both queues walks every waiting job, and log_event()
        # discards the text when logging is off. Only build it when it will be
        # printed, because this function runs on every arrival and completion.
        short_queue_str = [f"({x:.2f}, {y:.2f})" for x, y in short_queue]
        long_queue_str = [f"({x:.2f}, {y:.2f}, {z:.2f})" for x, y, z, _ in long_queue]
        queue_str = f"Short: {short_queue_str}, Long: {long_queue_str}"
        log_event(env.now, "Checking Policy", "--", "--", "--", "--", queue_str)

    if short_queue:
        log_event(env.now, "Schedule Short Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, server, preemptive=False))
    elif long_queue:
        log_event(env.now, "Schedule long Job", "--", "--", "--", "--", "--")
        env.process(serve_job(env, server, preemptive=True))

def serve_job(env, server, preemptive):
    """simpy process body that serves the next job from one of the two queues.

    Args:
        env: the simpy environment.
        server: resource whose request() accepts a `priority` keyword
            (run_simulation() passes a simpy.PriorityResource).
        preemptive: True serves the long queue (ordered by remaining_time(),
            interruptible); False serves the head of the short queue.

    Reads and updates the module-level short_queue, long_queue, current_job,
    start_time, job_processes and the response-time lists. Unless it returns
    early, it finishes by clearing current_job and calling check_schedule().
    """
    global current_job, start_time

    if preemptive:  # long job
        # Smallest predicted remaining time first.
        long_queue.sort(key=remaining_time)

        # Short jobs take precedence: let check_schedule() start one instead.
        if short_queue:
            check_schedule(env, server)
            return

        if current_job:
            # Service the running long job has received since it was (re)started.
            age_received = env.now - start_time
            if (current_job[1] - (current_job[2] + age_received)) > remaining_time(long_queue[0]): #remaining time current > remaining time of new
                # The queue head has less predicted work left: put the running
                # job back with its age brought up to date ...
                updated_job = (current_job[0], current_job[1], current_job[2] + age_received, current_job[3])
                log_event(env.now, "Long Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")

                long_queue.append(updated_job)
                long_queue.sort(key=remaining_time)
                # ... and interrupt the process registered for it. No cause is
                # passed, so the handler below (which only reacts to
                # "Short Job Arrival") does not re-queue the job a second time.
                # NOTE(review): unlike the short-job path, is_alive is not
                # checked before interrupting - confirm this cannot hit a
                # finished process.
                current_process = job_processes.get(start_time)
                if current_process:
                    current_process.interrupt()


        with server.request(priority=1) as req:
            yield req

            # The queues may have changed while waiting for the server.
            if short_queue:
                check_schedule(env, server)
                return

            if not long_queue:
                return
            next_job = long_queue.pop(0)
            log_event(env.now, "Serve Long", next_job[0], next_job[1], "--", next_job[0], f"Age: {next_job[2]}")
            start_time = env.now
            current_job = (next_job[0], next_job[1], next_job[2], next_job[3])
            # The actual service runs as its own process so it can be
            # interrupted; it is registered under its start time.
            current_process = env.process(serve_long_job(env, next_job))
            job_processes[start_time] = current_process
            try:
                yield current_process

                # Reached only when the job ran to completion.
                response_time = env.now - current_job[0]
                log_event(env.now, "Long Job Done", current_job[0], "--", "--", "DONE", f"L-response:{response_time}")

                # NOTE(review): job_details is built but never stored - the
                # append below is commented out.
                job_details = {
                    'response_time': response_time,
                    'actual_size': current_job[3],
                    'predicted_size': current_job[1]
                }
                #response_plong_times_details.append(job_details)
                response_times.append(response_time)
                #print(f'response_plong_times.append(response_time): {response_time}')
                response_plong_times.append(response_time)
                current_job = None
                if start_time in job_processes:
                    del job_processes[start_time]
                start_time = None

            except simpy.Interrupt as interrupt:
                # Only a preemption by a short job re-queues here; the
                # long-preempts-long path has already re-queued the job.
                if str(interrupt.cause) == "Short Job Arrival":
                    elapsed_time = env.now - start_time
                    updated_job = (current_job[0], current_job[1], current_job[2] + elapsed_time, current_job[3])
                    log_event(env.now, "***********Short Preempted Long", updated_job[0], updated_job[1], "--", updated_job[0], f"Age: {updated_job[2]}")
                    if start_time in job_processes:
                        del job_processes[start_time]

                    long_queue.append(updated_job)
                    long_queue.sort(key=remaining_time)
                    current_job = None
                    check_schedule(env, server)



    else:  # short job
        # Preempt the long job in service, if one is registered and still running.
        current_process = job_processes.get(start_time)
        if current_process and current_process.is_alive:
            current_process.interrupt("Short Job Arrival")

        with server.request(priority=0) as req:
            yield req
            if not short_queue:
                return
            # Short jobs are taken from the head of the queue and run for
            # their actual size without interruption.
            next_job = short_queue.pop(0)
            log_event(env.now, "Serve Short", next_job[0], next_job[1], "--", next_job[0], "--")
            yield env.timeout(next_job[1])
            log_event(env.now, "Short Job Done", next_job[0], "--", "--", "DONE", "--")
            response_times.append(env.now - next_job[0])
            response_pshort_times.append(env.now - next_job[0])

    # The server is free again: clear the marker and look for more work.
    current_job = None
    check_schedule(env, server)


def serve_long_job(env, job):
    """simpy process body: hold for the work a long job actually has left.

    `job` is a long-queue tuple; index 3 is its actual size and index 2 the
    age (service already received).
    """
    actual_size, attained_age = job[3], job[2]
    yield env.timeout(actual_size - attained_age)


def log_job_completion(env, job):
    """Record the completion of a long job and clear the in-service markers.

    Logs the event, appends the response time (env.now minus the arrival time
    stored at job[0]) to response_times and response_plong_times, then resets
    current_job / start_time and drops the process registered under the old
    start_time.

    NOTE(review): no caller is visible in this part of the file; serve_job()
    performs the same bookkeeping inline.
    """
    # Declared first: the original declaration sat mid-function, after other
    # statements, which is legal but easy to break when editing.
    global current_job, start_time

    log_event(env.now, "Long Job Done", job_id=job[0])
    response_time = env.now - job[0]
    response_times.append(response_time)
    response_plong_times.append(response_time)

    current_job = None
    # pop() with a default is the "delete if present" idiom.
    job_processes.pop(start_time, None)
    start_time = None


def job_generator(env, server, arriving_rate):
    """simpy process body that creates jobs with exponential inter-arrival times.

    Each arrival is classified as short or long, appended to the matching
    queue, and followed by a check_schedule() call. The arrival time doubles
    as the job id. The generator ends when job_size_distribution() raises
    EndOfDataException.

    Args:
        env: the simpy environment.
        server: resource handed on to check_schedule().
        arriving_rate: rate passed to random.expovariate() for the gaps.
    """
    global n_cheap_p, n_expensive_p

    while True:
        yield env.timeout(random.expovariate(arriving_rate))
        arrival = env.now

        try:
            job_size, cheap_predicted_service_time = job_size_distribution()
        except EndOfDataException:
            print("End of data reached. Terminating simulation.")
            return

        # Synthetic distributions carry no prediction of their own, so one is
        # generated; otherwise the value returned with the size is used.
        synthetic = dist in ('exponential', 'weibull')
        if synthetic:
            cheap_predicted_service_time = predict_service_time(job_size, cheap_alpha)
        n_cheap_p += 1
        log_event(env.now, "Job Arrival", arrival, job_size, cheap_predicted_service_time, "--", f"Threshold: {T:.3f}")

        # Short/long decision: separate noisy classifier, or cheap prediction vs T.
        if use_seperate_cheapp:
            is_short, _ = is_service_time_less_than_T(job_size, T, cheap_acc)
        else:
            is_short = cheap_predicted_service_time < T

        if not is_short:
            n_expensive_p += 1
            if synthetic:
                long_prediction = predict_service_time(job_size, expensive_alpha)
            else:
                long_prediction = cheap_predicted_service_time

            # (arrival, predicted size, age, actual size)
            long_queue.append((arrival, long_prediction, 0, job_size))
            log_event(env.now, "Append to Long", arrival, "--", long_prediction, "--", "--")
        else:
            short_queue.append((arrival, job_size))
            log_event(env.now, "Append to Short", arrival, "--", "--", "--", "--")

        check_schedule(env, server)


def dummy_job_generator(env, server):
    """simpy process body producing a fixed three-job scenario for debugging.

    Timeline: a short job after 1 time unit, a long job 2 units later, and a
    second short job shortly before the long job's size has elapsed (so that
    it arrives while the long job can still be in service).

    Fixes over the previous version, which could not run:
    * job_size_distribution() returns (size, predicted_size); the size is now
      unpacked instead of treating the tuple as a number.
    * predict_service_time() needs an alpha; the cheap alpha is used for short
      jobs and the expensive alpha for the long job, as in job_generator().
    * the last delay is clamped at 0 because a negative delay is not a valid
      timeout.
    """
    global n_cheap_p, n_expensive_p

    def draw_job(accept, alpha):
        # Sample until accept(size) holds; return (size, predicted_size).
        while True:
            size, dataset_prediction = job_size_distribution()
            if accept(size):
                break
        if dist == 'exponential' or dist == 'weibull':
            return size, predict_service_time(size, alpha)
        # Non-synthetic data brings its own prediction (same rule as job_generator).
        return size, dataset_prediction

    # First job - Short
    yield env.timeout(1)
    job = env.now
    job_size, predicted_service_time = draw_job(lambda s: s < T, cheap_alpha)
    n_cheap_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    short_queue.append((job, job_size))
    log_event(env.now, "Append to Short", job, "--", "--", "--", "--")
    check_schedule(env, server)

    # Second job - Long
    yield env.timeout(2)
    job = env.now
    job_size, predicted_service_time = draw_job(lambda s: s > T, expensive_alpha)
    n_cheap_p += 1
    n_expensive_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    long_queue.append((job, predicted_service_time, 0, job_size))
    log_event(env.now, "Append to Long", job, "--", predicted_service_time, "--", "--")
    check_schedule(env, server)

    # Third job - Short, arriving 0.5 before the long job's size has elapsed
    yield env.timeout(max(job_size - 0.5, 0))
    job = env.now
    job_size, predicted_service_time = draw_job(lambda s: s < T, cheap_alpha)
    n_cheap_p += 1
    log_event(env.now, "Job Arrival", job, job_size, predicted_service_time, "--", f"Threshold: {T:.3f}")
    short_queue.append((job, job_size))
    log_event(env.now, "Append to Short", job, "--", "--", "--", "--")
    check_schedule(env, server)



def run_simulation(arrival_rate, threshold, c1, c2, test_cheap_alpha, test_expen_alpha):
    """Run one simulation and report mean response times including prediction costs.

    Resets the module-level state, simulates until SIMULATION_TIME and prints
    a summary whose cost model depends on the threshold:

    * threshold == 0 (SPRPT): every job is charged c2 per expensive prediction.
    * threshold == inf (FCFS): no prediction costs.
    * otherwise: c1 per cheap prediction plus (c1 + c2) per expensive one.

    Args:
        arrival_rate: arrival rate passed to job_generator().
        threshold: short/long threshold; stored in the global T.
        c1: cost of a cheap prediction.
        c2: cost of an expensive prediction.
        test_cheap_alpha: stored in cheap_acc when use_seperate_cheapp is
            truthy, otherwise in cheap_alpha.
        test_expen_alpha: stored in expensive_alpha.

    Returns:
        (mean_response_time, mean_response_time_pshort,
        mean_response_time_plong, response_plong_times_details). In the
        general case the long-job mean is the cost-free mean over the last 90%
        of completed long jobs. The details list is returned as created
        (empty); nothing in the visible source appends to it.

    A mean over an empty sample is reported as 0. Previously only some of the
    expressions had that guard and the others raised ZeroDivisionError when a
    run completed no job of the relevant class.
    """
    global response_times, response_pshort_times, response_plong_times, response_plong_times_details
    global start_time, job_processes
    global T
    global n_cheap_p, n_expensive_p
    global short_queue, long_queue
    global current_job
    global cheap_alpha, expensive_alpha, cheap_acc

    def _mean(total, count):
        # Mean that degrades to 0 for an empty sample instead of raising.
        return total / count if count else 0

    # Re-initialize the lists at the start of each run
    response_times = []
    response_pshort_times = []
    response_plong_times = []
    response_plong_times_details = []
    short_queue = []
    long_queue = []

    T = threshold
    n_cheap_p = 0
    n_expensive_p = 0

    if use_seperate_cheapp:
        cheap_acc = test_cheap_alpha
    else:
        cheap_alpha = test_cheap_alpha

    expensive_alpha = test_expen_alpha

    start_time = None
    job_processes = {}
    current_job = None

    if threshold == 0: #SPRPT
        cheap_alpha = 0

    print(f'Simulating M/M/1 queue with server cost, lambda: {arrival_rate}, T:{T}')
    env = simpy.Environment()
    server = simpy.PriorityResource(env, capacity=1)
    print(f"{'TIME':<10} | {'EVENT':<18} | {'JOB ID':<10} | {'SIZE':<8} | {'PREDICTED SIZE':<15} | {'SERVER':<18} | {'NOTES'}")
    print("-" * 145)
    env.process(job_generator(env, server, arrival_rate))

    env.run(until=SIMULATION_TIME)

    mean_response_time_pshort = 0
    mean_response_time_plong = 0
    print(f'n_cheap_p: {n_cheap_p}')
    print(f'n_expensive_p: {n_expensive_p}')

    total_all, n_all = sum(response_times), len(response_times)
    total_short, n_short = sum(response_pshort_times), len(response_pshort_times)
    total_long, n_long = sum(response_plong_times), len(response_plong_times)
    run_tag = f"time units lambda: {arrival_rate}, T:{T}"

    if threshold == 0: #SPRPT, no short jobs
        print(f"\nMean Response Time without costs: {_mean(total_all, n_all):.2f} time units")
        mean_response_time = _mean(total_all + n_expensive_p * c2, n_all)
        print(f"\nMean Response Time with costs: {mean_response_time:.2f} {run_tag}")

        print(f"\nPredicted long: Mean Response Time without costs: {_mean(total_long, n_long):.2f} time units")
        mean_response_time_plong = _mean(total_long + n_expensive_p * c2, n_long)
        print(f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} {run_tag}")

    elif threshold == float('inf'): #FCFS-- no need for prediction, no long jobs
        print(f"\nMean Response Time without costs: {_mean(total_all, n_all):.2f} time units")
        mean_response_time = _mean(total_all, n_all)
        print(f"\nMean Response Time with costs: {mean_response_time:.2f} {run_tag}")

        print(f"\nPredicted short: Mean Response Time without costs: {_mean(total_short, n_short):.2f} time units")
        mean_response_time_pshort = _mean(total_short, n_short)
        print(f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} {run_tag}")

    else:
        print(f"\nMean Response Time without costs: {_mean(total_all, n_all):.2f} time units")
        mean_response_time = _mean(total_all + n_cheap_p * c1 + n_expensive_p * (c1 + c2), n_all)
        print(f"\nMean Response Time with costs: {mean_response_time:.2f} {run_tag}")

        print(f"\nPredicted short: Mean Response Time without costs: {_mean(total_short, n_short):.2f} time units")
        mean_response_time_pshort = _mean(total_short + n_cheap_p * c1, n_short)
        print(f"\nPredicted short: Mean Response Time with costs: {mean_response_time_pshort:.2f} {run_tag}")

        print(f"\nPredicted long: Mean Response Time without costs: {_mean(total_long, n_long):.2f} time units")
        mean_response_time_plong = _mean(total_long + n_expensive_p * (c1 + c2), n_long)
        print(f"\nPredicted long: Mean Response Time with costs: {mean_response_time_plong:.2f} {run_tag}")

        ##########################

        # Original mean response time calculation (including the entire list)
        mean_response_time_full_list = _mean(total_long, n_long)
        print(f"\nPredicted long: Mean Response Time for full list: {mean_response_time_full_list:.2f} time units")

        # Calculate mean response time with costs for the full list
        mean_response_time_with_costs_full_list = _mean(total_long + n_expensive_p * (c2+c1), n_long)
        print(f"\nPredicted long: Mean Response Time with costs for full list: {mean_response_time_with_costs_full_list:.2f} {run_tag}")

        # Drop the first 10% of completed long jobs before averaging
        start_index = int(n_long * 0.1)
        sliced_response_plong_times = response_plong_times[start_index:]
        total_sliced, n_sliced = sum(sliced_response_plong_times), len(sliced_response_plong_times)

        mean_response_time_without_first_10 = _mean(total_sliced, n_sliced)
        print(f"\nPredicted long: Mean Response Time without first 10%: {mean_response_time_without_first_10:.2f} time units")

        mean_response_time_with_costs_excluding_first_10 = _mean(total_sliced + n_expensive_p * (c2+c1), n_sliced)
        print(f"\nPredicted long: Mean Response Time with costs (excluding first 10%): {mean_response_time_with_costs_excluding_first_10:.2f} {run_tag}")

        # The value reported to the caller is the cost-free mean without the
        # first 10% (this overrides the with-costs figure computed above).
        mean_response_time_plong = mean_response_time_without_first_10

        print(f'response_plong_times:{response_plong_times}')
        ##########################

    return mean_response_time, mean_response_time_pshort, mean_response_time_plong, response_plong_times_details


def simulation_wrapper(arrival_rate, threshold, c1, c2, test_cheap_alpha, test_expen_alpha, n_runs=100):
    """Repeat run_simulation() and average the three mean response times.

    Args:
        arrival_rate, threshold, c1, c2, test_cheap_alpha, test_expen_alpha:
            forwarded unchanged to run_simulation().
        n_runs: number of repetitions (default 100, the previously hard-coded
            value).

    Returns:
        (average mean response time, average predicted-short mean response
        time, average predicted-long mean response time) over the runs.

    Raises:
        ValueError: n_runs is smaller than 1.

    NOTE(review): real_data_index is not reset between repetitions here, so
    with dist == 'real' each run continues where the previous one stopped in
    the dataset - confirm that this is intended.
    """
    global response_times, response_pshort_times, response_plong_times, short_queue, long_queue, start_time, job_processes, n_cheap_p, n_expensive_p

    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    response_times = []
    response_pshort_times = []
    response_plong_times = []
    short_queue = []
    long_queue = []

    # Initialize lists to store results
    mean_response_times = []
    mean_response_times_pshort = []
    mean_response_times_plong = []

    start_time = None
    job_processes = {}
    n_cheap_p = 0
    n_expensive_p = 0

    # Run the simulation n_runs times; the per-job details are not used here.
    for _ in range(n_runs):
        mean_response_time, mean_response_time_pshort, mean_response_time_plong, _details = run_simulation(arrival_rate, threshold, c1, c2, test_cheap_alpha, test_expen_alpha)
        mean_response_times.append(mean_response_time)
        mean_response_times_pshort.append(mean_response_time_pshort)
        mean_response_times_plong.append(mean_response_time_plong)

    print(f"mean_response_times_plong: {mean_response_times_plong}")

    # Calculate and print the average of each
    average_mean_response_time = sum(mean_response_times) / len(mean_response_times)
    average_mean_response_time_pshort = sum(mean_response_times_pshort) / len(mean_response_times_pshort)
    average_mean_response_time_plong = sum(mean_response_times_plong) / len(mean_response_times_plong)

    print(f"Average Mean Response Time: {average_mean_response_time:.2f} time units")
    print(f"Average Predicted Short Mean Response Time: {average_mean_response_time_pshort:.2f} time units")
    print(f"Average Predicted Long Mean Response Time: {average_mean_response_time_plong:.2f} time units")

    return average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong

#####
# Plot styling used by the test_cost_vs_* functions: each plotted series is
# paired with the next marker/colour via zip(), so at most len(markers) series
# are drawn.
markers = ['o', 'x', '^', 's', 'D', 'p']  # You can add more markers as needed
colors = ['b', 'g', 'r', 'c', 'm', 'y']  # Basic color abbreviations: b-blue, g-green, r-red, etc.

def test_cost_vs_cheap_alpha():
    """Sweep the cheap predictor's alpha and save the averaged response times.

    Runs simulation_wrapper() for every (threshold, cheap alpha) combination
    with zero prediction costs, writes one CSV row per combination under
    'res/' and, when PLOT_GRAPHS is truthy, saves a plot under 'graphs/'.
    """
    import os  # local import: only needed to make sure the output folders exist

    # Test parameters
    cheap_alpha_values = [0.1, 0.3, 0.5, 0.8, 0.9, 1]
    #T_values = [0, float('inf'), 1] #TODO
    T_values = [1]
    default_arrival_rate = 0.9
    default_c1 = 0
    default_c2 = 0
    default_expen_alpha = 0.2

    labels = {0:'SPRPT', float('inf'): 'FCFS', 1:'SkipPredict'}

    # Create the output folder up front so a long sweep cannot be lost to a
    # missing directory when the CSV is written.
    os.makedirs('res', exist_ok=True)

    results = []
    for t in T_values:
        for calpha in cheap_alpha_values:
            avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, default_c1, default_c2, calpha, default_expen_alpha)
            results.append({
                'Alg': labels[t],
                'Alpha': calpha,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': default_c1,
                # Record the threshold actually simulated (was hard-coded to 1).
                'Default T': t
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/cost_vs_cheapAlpha_results_{predictor}_alpha_{default_expen_alpha}_dist_{dist}_ext.csv', index=False)

    # Plotting the results
    if PLOT_GRAPHS:
        os.makedirs('graphs', exist_ok=True)
        unique_labels = results_df['Alg'].unique()

        for label, marker, color in zip(unique_labels, markers, colors):
            df_subset = results_df[results_df['Alg'] == label]
            plt.plot(df_subset['Alpha'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label)

        plt.xlabel(r'Cheap $\alpha$')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'graphs/cost_vs_cheapAlpha_{predictor}_alpha_{default_expen_alpha}_dist_{dist}_ext.png')
        plt.clf()

def test_cost_vs_expensive_alpha():
    """Sweep the expensive predictor's alpha and save the averaged response times.

    Runs simulation_wrapper() for every (threshold, expensive alpha)
    combination with zero prediction costs, writes one CSV row per combination
    under 'res/' and, when PLOT_GRAPHS is truthy, saves a plot under 'graphs/'.
    """
    import os  # local import: only needed to make sure the output folders exist

    # Test parameters
    expensive_alpha_values = [0.1, 0.3, 0.5, 0.8, 0.9, 1]
    #T_values = [0, float('inf'), 1] #TODO
    T_values = [1]
    default_arrival_rate = 0.9
    default_c1 = 0
    default_c2 = 0
    default_cheap_alpha = 0.8

    labels = {0:'SPRPT', float('inf'): 'FCFS', 1:'SkipPredict'}

    # Create the output folder up front so a long sweep cannot be lost to a
    # missing directory when the CSV is written.
    os.makedirs('res', exist_ok=True)

    results = []
    for t in T_values:
        for exalpha in expensive_alpha_values:
            avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, default_c1, default_c2, default_cheap_alpha, exalpha)
            results.append({
                'Alg': labels[t],
                'Alpha': exalpha,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': default_c1,
                # Record the threshold actually simulated (was hard-coded to 1).
                'Default T': t
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/cost_vs_expAlpha_results_{predictor}_alpha_{default_cheap_alpha}_dist_{dist}_ext.csv', index=False)

    # Plotting the results
    if PLOT_GRAPHS:
        os.makedirs('graphs', exist_ok=True)
        unique_labels = results_df['Alg'].unique()

        for label, marker, color in zip(unique_labels, markers, colors):
            df_subset = results_df[results_df['Alg'] == label]
            plt.plot(df_subset['Alpha'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label)

        plt.xlabel(r'Expensive $\alpha$')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'graphs/cost_vs_expAlpha_{predictor}_alpha_{default_cheap_alpha}_ext.png')
        plt.clf()

def test_cost_vs_ratio():
    """Sweep the c2/c1 price ratio for each policy and save the averaged results.

    For every threshold (SPRPT = 0, FCFS = inf, SkipPredict = 1, or 4 when
    dist == 'real') and every ratio, resets real_data_index, runs
    simulation_wrapper() with c2 = ratio * c1 and records one CSV row. When
    PLOT_GRAPHS is truthy a plot is saved under 'graphs/'.

    NOTE(review): the CSV file name uses a module-level `dataset` that is not
    defined in the visible part of this file - confirm it is bound before this
    function is called.
    """
    global real_data_index
    import os  # local import: only needed to make sure the output folders exist

    # Test parameters
    ratio_values = [1, 2, 3, 5, 8]
    T_values = [0, float('inf'), 1]
    default_arrival_rate = 0.9
    default_c1 = 0.5
    default_cheap_alpha = 0.8
    default_expen_alpha = 0.2

    labels = {0:'SPRPT', float('inf'): 'FCFS', 1:'SkipPredict'}

    if dist == 'real':
        T_values = [0, float('inf'), 4]
        labels = {0: 'SPRPT', float('inf'): 'FCFS', 4: 'SkipPredict'}

    # Create the output folder up front so a long sweep cannot be lost to a
    # missing directory when the CSV is written.
    os.makedirs('res', exist_ok=True)

    results = []
    for t in T_values:
        for ratio in ratio_values:
            real_data_index = 0
            avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, default_c1, ratio * default_c1, default_cheap_alpha, default_expen_alpha)
            results.append({
                'Alg': labels[t],
                'Ratio': ratio,
                'Average Mean Response Time': avg_mean,
                'Average PShort Response Time': avg_pshort,
                'Average PLong Response Time': avg_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': default_c1,
                'Default calpha': default_cheap_alpha,
                'Default expalpha': default_expen_alpha,
                # Record the threshold actually simulated; the old constant 1
                # mislabelled the rows for T = 0, inf and 4.
                'Default T': t
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/cost_vs_ratio_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_dataset_{dataset}_ext.csv', index=False)

    # Plotting the results
    if PLOT_GRAPHS:
        os.makedirs('graphs', exist_ok=True)
        unique_labels = results_df['Alg'].unique()

        for label, marker, color in zip(unique_labels, markers, colors):
            df_subset = results_df[results_df['Alg'] == label]
            plt.plot(df_subset['Ratio'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label)

        plt.xlabel('Prices Ratio')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'graphs/cost_vs_ratio_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
        plt.clf()


def test_cost_vs_T(given_c1 = 0.5, given_c2=2):
    """Compare SPRPT, FCFS and SkipPredict over a range of thresholds.

    SPRPT (T = 0) and FCFS (T = inf) are simulated once each; SkipPredict is
    simulated for every value in T_test_values. real_data_index is reset
    before each simulation_wrapper() call. Results go to a CSV under 'res/'
    and, when PLOT_GRAPHS is truthy, to a plot under 'graphs/'.

    Args:
        given_c1: cost of a cheap prediction.
        given_c2: cost of an expensive prediction.

    NOTE(review): the CSV file name uses a module-level `dataset` that is not
    defined in the visible part of this file - confirm it is bound before this
    function is called.
    """
    global real_data_index
    import os  # local import: only needed to make sure the output folders exist

    T_values = [0, float('inf'), 1]
    T_test_values = [0.1, 0.5, 1, 1.5, 2, 4, 5, 8]
    default_arrival_rate = 0.9
    fixed_c1 = given_c1
    fixed_c2 = given_c2
    default_cheap_alpha = 0.8
    default_expen_alpha = 0.2

    labels = {0: 'SPRPT', float('inf'): 'FCFS', 1: 'SkipPredict'}

    if dist == 'real':
        T_values = [0, float('inf'), 4]
        labels = {0: 'SPRPT', float('inf'): 'FCFS', 4: 'SkipPredict'}

    # Create the output folder up front so a long sweep cannot be lost to a
    # missing directory when the CSV is written.
    os.makedirs('res', exist_ok=True)

    # Running the test
    results = []
    for t in T_values:
        # The two baselines have a single meaningful threshold; the
        # SkipPredict entry stands for the whole sweep.
        thresholds = [t] if t in [0, float('inf')] else T_test_values
        for test_value in thresholds:
            real_data_index = 0
            average_mean_response_time, average_mean_response_time_pshort, average_mean_response_time_plong = simulation_wrapper(
                default_arrival_rate, test_value, fixed_c1, fixed_c2, default_cheap_alpha, default_expen_alpha)
            results.append({
                'Alg': labels[t],
                'T Value': test_value,
                'Average Mean Response Time': average_mean_response_time,
                'Average PShort Response Time': average_mean_response_time_pshort,
                'Average PLong Response Time': average_mean_response_time_plong,
                'Default arrival rate': default_arrival_rate,
                'Default c1': fixed_c1,
                'Default c2': fixed_c2,
                'Default calpha': default_cheap_alpha,
                'Default expalpha': default_expen_alpha
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(results)
    results_df.to_csv(f'res/cost_vs_T_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_dataset_{dataset}_ext_0.8.csv', index=False)

    if PLOT_GRAPHS:
        os.makedirs('graphs', exist_ok=True)
        # Plotting the results
        for (t, marker, color) in zip(T_values, markers, colors):
            label_t = labels[t]
            df_subset = results_df[results_df['Alg'] == label_t]

            plt.plot(df_subset['T Value'], df_subset['Average Mean Response Time'], marker=marker, color=color,
                     label=label_t)

        plt.xlabel(r'$T$')
        plt.ylabel('Cost')
        plt.legend()
        plt.grid(True)
        plt.savefig(f'graphs/cost_vs_T_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
        plt.clf()


def test_cost_vs_arrivalrate(given_c1 = 0.5, given_c2=2):
    """Sweep the arrival rate for SPRPT, FCFS and SkipPredict and save the results.

    For every (T, arrival rate) combination the module-level
    ``real_data_index`` is reset to 0 and ``simulation_wrapper`` is called
    once.  One row per call is written to a CSV file under ``res/``; when
    ``PLOT_GRAPHS`` is truthy a line plot is also saved under ``graphs/``.

    Parameters:
        given_c1: forwarded to ``simulation_wrapper`` as its third argument
            and stored in the 'Default c1' column.
        given_c2: forwarded to ``simulation_wrapper`` as its fourth argument
            and stored in the 'Default c2' column.
    """
    global real_data_index

    # The third threshold (labelled SkipPredict) is 4 when running on the
    # real dataset and 1 for every other distribution.
    skip_threshold = 4 if dist == 'real' else 1
    T_values = [0, float('inf'), skip_threshold]
    labels = {0: 'SPRPT', float('inf'): 'FCFS', skip_threshold: 'SkipPredict'}

    arrival_rates = (0.5, 0.6, 0.7, 0.9, 0.95)
    calpha, expalpha = 0.8, 0.2

    # Running the test
    rows = []
    for threshold in T_values:
        for rate in arrival_rates:
            # Reset the shared index before each individual run.
            real_data_index = 0
            mean_rt, pshort_rt, plong_rt = simulation_wrapper(
                rate, threshold, given_c1, given_c2, calpha, expalpha)
            rows.append({
                'Alg': labels[threshold],
                'Arrival Rate': rate,
                'Average Mean Response Time': mean_rt,
                'Average PShort Response Time': pshort_rt,
                'Average PLong Response Time': plong_rt,
                'Default T': threshold,
                'Default c1': given_c1,
                'Default c2': given_c2,
                'Default calpha': calpha,
                'Default expalpha': expalpha
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(rows)
    csv_path = f'res/cost_vs_arrivalrate_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_dataset_{dataset}_ext.csv'
    results_df.to_csv(csv_path, index=False)

    if not PLOT_GRAPHS:
        return

    # One curve per algorithm, each with its own marker and color.
    for threshold, marker, color in zip(T_values, markers, colors):
        alg_name = labels[threshold]
        per_alg = results_df[results_df['Alg'] == alg_name]
        plt.plot(per_alg['Arrival Rate'], per_alg['Average Mean Response Time'],
                 marker=marker, color=color, label=alg_name)

    plt.xlabel('Arrival Rate')
    plt.ylabel('Cost')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'graphs/cost_vs_arrivalrate_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
    plt.clf()


import matplotlib.pyplot as plt
import pandas as pd

def test_cost_vs_alpha():
    """Sweep the cheap and the expensive alpha separately and save the results.

    The cheap alpha is varied first (expensive alpha held at 0.2), then the
    expensive alpha is varied (cheap alpha held at 0.8).  Each run calls
    ``simulation_wrapper`` with arrival rate 0.9, T = 0.5, c1 = 0.5 and
    c2 = 2.  Rows are written to a CSV under ``res/``; when ``PLOT_GRAPHS``
    is truthy one curve per alpha type is saved under ``graphs/``.
    """
    # Test parameters
    alphas = [0.1, 0.3, 0.5, 0.8, 0.9, 1]
    thresholds = [0.5]  # You can adjust the thresholds as needed
    rate = 0.9
    c1, c2 = 0.5, 2
    fixed_cheap_alpha = 0.8
    fixed_expen_alpha = 0.2
    alg_name = 'SkipPredict'

    # Each sweep maps the varied alpha to the (cheap, expensive) pair handed
    # to the simulation; the cheap sweep runs before the expensive one.
    sweeps = (
        ('Cheap', lambda a: (a, fixed_expen_alpha)),
        ('Expensive', lambda a: (fixed_cheap_alpha, a)),
    )

    rows = []
    for alpha_type, alpha_pair in sweeps:
        for threshold in thresholds:
            for alpha in alphas:
                cheap_a, expen_a = alpha_pair(alpha)
                mean_rt, _pshort_rt, _plong_rt = simulation_wrapper(
                    rate, threshold, c1, c2, cheap_a, expen_a)
                rows.append({
                    'Alg': alg_name,
                    'Alpha Type': alpha_type,
                    'Alpha': alpha,
                    'Average Mean Response Time': mean_rt,
                })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(rows)
    results_df.to_csv(f'res/cost_vs_alpha_results_dist_{dist}_ext.csv', index=False)

    # Plotting the results
    if not PLOT_GRAPHS:
        return

    plt.figure()
    for alpha_type in ['Cheap', 'Expensive']:
        per_type = results_df[results_df['Alpha Type'] == alpha_type]
        plt.plot(per_type['Alpha'], per_type['Average Mean Response Time'],
                 label=f'Modify {alpha_type} Prediction')

    plt.xlabel(r'$\alpha$', fontsize = 18)
    plt.ylabel('Cost', fontsize = 18)
    plt.legend()
    plt.grid(True)
    plt.savefig('graphs/cost_vs_alpha_ext.png')
    plt.clf()


def test_cost_vs_c1():
    """Sweep the c1/c2 ratio for SPRPT, FCFS and SkipPredict and save the results.

    c2 is held at 4 and c1 is set to ``ratio * c2`` for each ratio.  Every
    run calls ``simulation_wrapper`` with arrival rate 0.95, cheap alpha 0.8
    and expensive alpha 0.2.  Rows are written to a CSV under ``res/``; when
    ``PLOT_GRAPHS`` is truthy one curve per algorithm is saved under
    ``graphs/``.
    """
    # Test parameters
    ratios = [0.1, 0.2, 0.3, 0.5, 0.8, 1]
    thresholds = [0, float('inf'), 1]
    rate = 0.95
    c2 = 4
    calpha, expalpha = 0.8, 0.2

    alg_names = {0:'SPRPT', float('inf'): 'FCFS', 1:'SkipPredict'}

    rows = []
    for threshold in thresholds:
        for ratio in ratios:
            c1 = ratio*c2
            mean_rt, pshort_rt, plong_rt = simulation_wrapper(
                rate, threshold, c1, c2, calpha, expalpha)
            rows.append({
                'Alg': alg_names[threshold],
                'Ratio': ratio,
                'Average Mean Response Time': mean_rt,
                'Average PShort Response Time': pshort_rt,
                'Average PLong Response Time': plong_rt,
                'Default arrival rate': rate,
                'Default c1': c1,
                'Default c2': c2,
                'Default calpha': calpha,
                'Default expalpha': expalpha,
                # NOTE(review): recorded as the constant 1 for every row,
                # regardless of the threshold actually simulated -- confirm
                # this is intended.
                'Default T': 1
            })

    # Saving the results to a CSV file
    results_df = pd.DataFrame(rows)
    csv_path = f'res/cost_vs_c1_results_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_dist_{dist}_ext.csv'
    results_df.to_csv(csv_path, index=False)

    # Plotting the results
    if not PLOT_GRAPHS:
        return

    # Algorithms are plotted in order of first appearance in the results.
    for alg, marker, color in zip(results_df['Alg'].unique(), markers, colors):
        per_alg = results_df[results_df['Alg'] == alg]
        plt.plot(per_alg['Ratio'], per_alg['Average Mean Response Time'],
                 marker=marker, color=color, label=alg)

    plt.xlabel('c1')
    plt.ylabel('Cost')
    plt.legend()
    plt.grid(True)
    plt.savefig(f'graphs/cost_vs_c1_{predictor}_alpha_{cheap_alpha}_{expensive_alpha}_ext.png')
    plt.clf()

# def test_cost_vs_accuracy():
#     cheap_alpha_prices = {0.05: 4, 0.1: 3.5, 0.3: 2.8, 0.5: 2.3, 0.8: 2.25, 0.9: 2.2}
#     expensive_alpha_prices = {0.05: 4, 0.1: 3.5, 0.3: 2.8, 0.5: 2.3, 0.8: 2.25, 0.9: 2.2}
#
# # Test parameters
#     T_values = [1]  # Only using SkipPredict for simplicity
#     default_arrival_rate = 0.9
#     labels = {1: 'SkipPredict'}
#
#     results = []
#     for t in T_values:
#         for calpha, cprice in cheap_alpha_prices.items():
#             for exp_alpha, eprice in expensive_alpha_prices.items():
#                 avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, cprice,
#                                                                      eprice, calpha, exp_alpha)
#                 results.append({
#                     'Algorithm': labels[t],
#                     'Cheap Alpha': calpha,
#                     'Expensive Alpha': exp_alpha,
#                     'Average Mean Response Time': avg_mean,
#                     'Average PShort Response Time': avg_pshort,
#                     'Average PLong Response Time': avg_plong,
#                     'Cheap Alpha Cost': cprice,
#                     'Expensive Alpha Cost': eprice,
#                     'Default Arrival Rate': default_arrival_rate
#                 })
#
#     # Saving the results to a CSV file
#     results_df = pd.DataFrame(results)
#     filename = f'res/cost_vs_alpha_results.csv'
#     results_df.to_csv(filename, index=False)
#
#     # # Plotting the results
#
#     if PLOT_GRAPHS:
#         results_df = pd.read_csv(f'res/cost_vs_alpha_results.csv')
#
#         # Create a pivot table with Cheap Alpha as rows, Expensive Alpha as columns, and Average Mean Response Time as values
#         pivot_table = results_df.pivot(index='Cheap Alpha', columns='Expensive Alpha', values='Average Mean Response Time')
#
#         # Create the row labels with Cheap Alpha and its cost
#         row_labels = [f"{row} (c1: {results_df[results_df['Cheap Alpha'] == row]['Cheap Alpha Cost'].values[0]})" for row in
#                       pivot_table.index]
#
#         # Create the column labels with Expensive Alpha and its cost
#         col_labels = [f"{col} (c2: {results_df[results_df['Expensive Alpha'] == col]['Expensive Alpha Cost'].values[0]})" for col in
#                       pivot_table.columns]
#
#         # Create the heat matrix plot
#         plt.figure(figsize=(12, 10))
#         sns.heatmap(pivot_table, annot=True, cmap='YlGnBu', xticklabels=col_labels, yticklabels=row_labels)
#         plt.xlabel(r'Expensive $\alpha$', fontsize=14)
#         plt.ylabel(r'Cheap $\alpha$', fontsize=14)
#         plt.savefig(f'graphs/cost_vs_acc_ext_arrival_{default_arrival_rate}.png')
#         plt.clf()


def test_cost_vs_accuracy():
    """Run SkipPredict over a grid of cheap/expensive predictor settings.

    Each key of ``cheap_acc_prices`` is passed to ``simulation_wrapper`` as
    its fifth argument with the mapped value as the third argument (c1);
    each key of ``expensive_alpha_prices`` is passed as the sixth argument
    with the mapped value as the fourth argument (c2).  Every combination is
    simulated once with arrival rate 0.9 and T = 1.

    The rows are written to a CSV under ``res/cost_acc/``.  When
    ``PLOT_GRAPHS`` is truthy that same file is read back and rendered as a
    heat map saved under ``graphs/``.
    """
    cheap_acc_prices = {0.05: 0.3,  0.5: 1.3,  0.9: 3}
    expensive_alpha_prices = {0.05: 5, 0.6: 2.5, 0.9: 2.0}

    # Test parameters
    T_values = [1]  # Only using SkipPredict for simplicity
    default_arrival_rate = 0.9
    labels = {1: 'SkipPredict'}

    results = []
    for t in T_values:
        for calpha, cprice in cheap_acc_prices.items():
            for exp_alpha, eprice in expensive_alpha_prices.items():
                avg_mean, avg_pshort, avg_plong = simulation_wrapper(default_arrival_rate, t, cprice,
                                                                     eprice, calpha, exp_alpha)
                results.append({
                    'Algorithm': labels[t],
                    'Cheap Alpha': calpha,
                    'Expensive Alpha': exp_alpha,
                    'Average Mean Response Time': avg_mean,
                    'Average PShort Response Time': avg_pshort,
                    'Average PLong Response Time': avg_plong,
                    'Cheap Alpha Cost': cprice,
                    'Expensive Alpha Cost': eprice,
                    'Default Arrival Rate': default_arrival_rate
                })

    # Saving the results to a CSV file.  A single path variable is used for
    # both the write and the read-back below so the two cannot drift apart.
    results_df = pd.DataFrame(results)
    filename = 'res/cost_acc/cost_vs_alpha_results.csv'
    results_df.to_csv(filename, index=False)

    # Plotting the results
    if PLOT_GRAPHS:
        results_df = pd.read_csv(filename)

        # Create a pivot table with Cheap Alpha as rows, Expensive Alpha as columns, and Average Mean Response Time as values
        pivot_table = results_df.pivot(index='Cheap Alpha', columns='Expensive Alpha', values='Average Mean Response Time')

        # Create the row labels with Cheap Alpha and its cost
        row_labels = [f"{row} (c1: {results_df[results_df['Cheap Alpha'] == row]['Cheap Alpha Cost'].values[0]})" for row in
                      pivot_table.index]

        # Column labels show 1 - alpha together with the expensive cost
        col_labels = [
            f"{1 - col:.1f} (c2: {results_df[results_df['Expensive Alpha'] == col]['Expensive Alpha Cost'].values[0]})"
            for col in pivot_table.columns]

        # Create the heat matrix plot
        # NOTE(review): `sns` must be in scope here; the seaborn import at the
        # top of the file is currently commented out -- re-enable it before
        # turning PLOT_GRAPHS on for this test.
        plt.figure(figsize=(12, 10))
        ax = sns.heatmap(pivot_table, annot=True, cmap='YlGnBu', xticklabels=col_labels, yticklabels=row_labels)
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=15)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=15)

        plt.xlabel(r'$1-\alpha$ (Expensive Prediction)', fontsize=18)
        plt.ylabel(r'Probability (Accurate Cheap Prediction)', fontsize=18)
        plt.savefig(f'graphs/cost_vs_acc_ext_arrival_{default_arrival_rate}.png')
        plt.clf()


if __name__ == "__main__":
    # Compressed trace file for every known dataset name.
    dataset_file_paths = {
        'twosigma': 'real_dataset/jvupredict_twosigma.csv.gz',
        'google': 'real_dataset/jvupredict_google_all_features.csv.gz',
        'mustang': 'real_dataset/jvupredict_mustang_full.csv.gz',
        'trinity': 'real_dataset/jvupredict_trinity.csv.gz',
    }

    #datasets = ['twosigma', 'google', 'mustang', 'trinity']
    datasets = ['twosigma', 'google', 'trinity']

    # `dataset` and `real_data` are module-level names; the result file
    # names built by the test functions interpolate `dataset`.
    for dataset in datasets:
        file_path = dataset_file_paths[dataset]
        real_data = load_and_prepare_data(file_path)

        #test_cost_vs_ratio()
        test_cost_vs_T()
        #test_cost_vs_arrivalrate()



    #simulation_wrapper(0.9, 5, 0, 0, 0, 0)

    #TESTS

    #test_cost_vs_cheap_alpha()
    #test_cost_vs_expensive_alpha()
    #test_cost_vs_alpha()


    #test_cost_vs_ratio()
    #test_cost_vs_T(3.5, 4)
    #test_cost_vs_arrivalrate(3.5, 4)

    #test_cost_vs_arrivalrate(0.5, 4)
    #test_cost_vs_c1()

    # Test it only with uniform predictor

    #use_seperate_cheapp = 1
    #predictor = 'uniP'
    #dist = 'exponential'
    #test_cost_vs_accuracy()


    #simulation_wrapper(0.9, 5, 0, 0, 0, 0)
    #simulation_wrapper(0.6, 4, 0, 0, 0, 0)
    #simulation_wrapper(0.6, float('inf'), 0, 0, 0, 0)
