use scallop_codegen::scallop;
use scallop_runtime::tags::*;

scallop! {
  Sort2 {
    decl digit(Symbol, Int).
    decl sort_2(Int).

    sort_2(0) :- digit(0, DA), digit(1, DB), DA <= DB.
    sort_2(1) :- digit(0, DA), digit(1, DB), DA > DB.
  }
}

scallop! {
  Sort3 {
    decl digit(Symbol, Int).
    decl sort_3(Int).
    decl digit_abc(Int, Int, Int).

    digit_abc(DA, DB, DC) :- digit(0, DA), digit(1, DB), digit(2, DC).

    sort_3(0) :- digit_abc(DA, DB, DC), DA <= DB, DB <= DC.
    sort_3(1) :- digit_abc(DA, DB, DC), DA <= DC, DC < DB.
    sort_3(2) :- digit_abc(DA, DB, DC), DB < DA, DA <= DC.
    sort_3(3) :- digit_abc(DA, DB, DC), DB <= DC, DC < DA.
    sort_3(4) :- digit_abc(DA, DB, DC), DC < DA, DA <= DB.
    sort_3(5) :- digit_abc(DA, DB, DC), DC < DB, DB < DA.
  }
}

scallop! {
  Sort4 {
    decl digit(Symbol, Int).
    decl sort_4(Int).
    decl digits(Int, Int, Int, Int).

    digits(D0, D1, D2, D3) :- digit(0, D0), digit(1, D1), digit(2, D2), digit(3, D3).

    sort_4(0) :- digits(D0, D1, D2, D3), D0 <= D1, D1 <= D2, D2 <= D3. // 0, 1, 2, 3
    sort_4(1) :- digits(D0, D1, D2, D3), D0 <= D1, D1 <= D3, D3 < D2. // 0, 1, 3, 2
    sort_4(2) :- digits(D0, D1, D2, D3), D0 <= D2, D2 < D1, D1 <= D3. // 0, 2, 1, 3
    sort_4(3) :- digits(D0, D1, D2, D3), D0 <= D2, D2 <= D3, D3 < D1. // 0, 2, 3, 1
    sort_4(4) :- digits(D0, D1, D2, D3), D0 <= D3, D3 < D1, D1 <= D2. // 0, 3, 1, 2
    sort_4(5) :- digits(D0, D1, D2, D3), D0 <= D3, D3 < D2, D2 < D1. // 0, 3, 2, 1

    sort_4(6) :- digits(D0, D1, D2, D3), D1 < D0, D0 <= D2, D2 <= D3. // 1, 0, 2, 3
    sort_4(7) :- digits(D0, D1, D2, D3), D1 < D0, D0 <= D3, D3 < D2. // 1, 0, 3, 2
    sort_4(8) :- digits(D0, D1, D2, D3), D1 <= D2, D2 < D0, D0 <= D3. // 1, 2, 0, 3
    sort_4(9) :- digits(D0, D1, D2, D3), D1 <= D2, D2 <= D3, D3 < D0. // 1, 2, 3, 0
    sort_4(10) :- digits(D0, D1, D2, D3), D1 <= D3, D3 < D0, D0 <= D2. // 1, 3, 0, 2
    sort_4(11) :- digits(D0, D1, D2, D3), D1 <= D3, D3 < D2, D2 < D0. // 1, 3, 2, 0

    sort_4(12) :- digits(D0, D1, D2, D3), D2 < D0, D0 <= D1, D1 <= D3. // 2, 0, 1, 3
    sort_4(13) :- digits(D0, D1, D2, D3), D2 < D0, D0 <= D3, D3 < D1. // 2, 0, 3, 1
    sort_4(14) :- digits(D0, D1, D2, D3), D2 < D1, D1 < D0, D0 <= D3. // 2, 1, 0, 3
    sort_4(15) :- digits(D0, D1, D2, D3), D2 < D1, D1 <= D3, D3 < D0. // 2, 1, 3, 0
    sort_4(16) :- digits(D0, D1, D2, D3), D2 <= D3, D3 < D0, D0 <= D1. // 2, 3, 0, 1
    sort_4(17) :- digits(D0, D1, D2, D3), D2 <= D3, D3 < D1, D1 < D0. // 2, 3, 1, 0

    sort_4(18) :- digits(D0, D1, D2, D3), D3 < D0, D0 <= D1, D1 <= D2. // 3, 0, 1, 2
    sort_4(19) :- digits(D0, D1, D2, D3), D3 < D0, D0 <= D2, D2 < D1. // 3, 0, 2, 1
    sort_4(20) :- digits(D0, D1, D2, D3), D3 < D1, D1 < D0, D0 <= D2. // 3, 1, 0, 2
    sort_4(21) :- digits(D0, D1, D2, D3), D3 < D1, D1 <= D2, D2 < D0. // 3, 1, 2, 0
    sort_4(22) :- digits(D0, D1, D2, D3), D3 < D2, D2 < D0, D0 <= D1. // 3, 2, 0, 1
    sort_4(23) :- digits(D0, D1, D2, D3), D3 < D2, D2 < D1, D1 < D0. // 3, 2, 1, 0
  }
}

fn ten_normalized_numbers() -> Vec<f32> {
  let mut numbers = (0..=9).map(|i| i as f32 * 0.1).collect::<Vec<_>>();
  let sum = numbers.iter().fold(0.0, |a, p| a + p);
  for i in 0..=9 {
    numbers[i] /= sum;
  }
  numbers
}

fn disjunction(i: usize) -> Vec<(f32, (usize, i64))> {
  ten_normalized_numbers().into_iter().enumerate().map(|(j, f)| {
    (f, (i, j as i64))
  }).collect()
}

fn main() {
  println!("SORT 2");

  let mut sort2 = Sort2::<ProbProofs>::new();

  // Initialize data
  sort2.digit().insert_disjunction(disjunction(0));
  sort2.digit().insert_disjunction(disjunction(1));

  // Execute the program
  sort2.run();

  // Investigate the results
  for elem in sort2.sort_2().complete().into_iter() {
    println!("{:?}: #Proofs = {}", elem.tup, elem.tag.proofs.len());
  }

  println!("SORT 3");

  let mut sort3 = Sort3::<ProbProofs>::new();

  // Initialize data
  sort3.digit().insert_disjunction(disjunction(0));
  sort3.digit().insert_disjunction(disjunction(1));
  sort3.digit().insert_disjunction(disjunction(2));

  // Execute the program
  sort3.run();

  // Investigate the results
  for elem in sort3.sort_3().complete().into_iter() {
    println!("{:?}: #Proofs = {}", elem.tup, elem.tag.proofs.len());
  }

  println!("SORT 4");

  let mut sort4 = Sort4::<ProbProofs>::new();

  // Initialize data
  sort4.digit().insert_disjunction(disjunction(0));
  sort4.digit().insert_disjunction(disjunction(1));
  sort4.digit().insert_disjunction(disjunction(2));
  sort4.digit().insert_disjunction(disjunction(3));

  // Execute the program
  sort4.run();

  // Investigate the results
  for elem in sort4.sort_4().complete().into_iter() {
    println!("{:?}: #Proofs = {}", elem.tup, elem.tag.proofs.len());
  }
}
