import argparse
import os
import io
import subprocess
import re
import csv
from tqdm.auto import tqdm
import pandas as pd
from datasets import Dataset
import chess
import chess.pgn


# Parse game URL
def parse_game_url(url):
    """Split a Lichess game URL into its game hash and move count.

    The URL may carry an optional ``/black`` or ``/white`` suffix and an
    optional ``#<number>`` fragment.

    Args:
        url: Text containing a ``lichess.org/<hash>`` game link.

    Returns:
        A ``(game_hash, num_moves)`` tuple. ``num_moves`` is the number found
        after ``#`` plus one, or ``None`` when the URL carries no number.

    Raises:
        ValueError: If no Lichess game link is found in ``url``.
    """
    found = re.search(r"lichess\.org/([\w]+)(?:/black|/white)?#?(\d*)", url)
    if found is None:
        raise ValueError(f"Invalid URL: {url}")
    game_hash, step_text = found.groups()
    if step_text:
        return game_hash, int(step_text) + 1
    return game_hash, None


def load_puzzle_csv(csv_file_path):
    """Load the two-move puzzles of a puzzle CSV into a column-oriented dict.

    The first CSV row is used as the header. Rows whose ``Moves`` field does
    not consist of exactly two space-separated moves are skipped.

    Args:
        csv_file_path: Path to the puzzle CSV file. It must provide at least
            the ``Moves``, ``FEN`` and ``GameUrl`` columns.

    Returns:
        A dict mapping column name to a list with one entry per kept puzzle.
        The derived columns ``GameHash``, ``NumMoves``, ``PrevFEN``,
        ``LastMoveUCI`` and ``AnswerUCI`` come first, followed by every CSV
        column except ``Moves`` and ``FEN``. ``Themes`` and ``OpeningTags``
        are split on spaces into lists; ``Rating``, ``RatingDeviation``,
        ``Popularity`` and ``NbPlays`` are converted to ``int``.

    Raises:
        ValueError: If a kept row has a ``GameUrl`` that ``parse_game_url``
            rejects, or a numeric column that is not an integer.
    """
    # Count the lines in pure Python instead of shelling out to `wc -l`, which
    # does not exist on every platform. The count only sizes the progress bar.
    with open(csv_file_path, mode="rb") as raw_file:
        num_lines = sum(1 for _ in raw_file)

    list_columns = ("Themes", "OpeningTags")
    int_columns = ("Rating", "RatingDeviation", "Popularity", "NbPlays")

    # The csv module asks for files to be opened with newline="" so that it
    # can do its own newline handling.
    with open(csv_file_path, mode="r", newline="") as file:
        csv_reader = csv.DictReader(file)
        # Reading `fieldnames` consumes the header row; it is None for an
        # empty file, in which case only the derived columns are returned.
        headers = csv_reader.fieldnames or []
        puzzles = {
            "GameHash": [],
            "NumMoves": [],
            "PrevFEN": [],
            "LastMoveUCI": [],
            "AnswerUCI": [],
        } | {header: [] for header in headers if header not in ["Moves", "FEN"]}
        # Headers: PuzzleId,FEN,Moves,Rating,RatingDeviation,Popularity,NbPlays,Themes,GameUrl,OpeningTags
        for row in tqdm(
            csv_reader,
            desc="Reading puzzles CSV",
            total=max(num_lines - 1, 0),  # minus the header row
        ):
            moves = row["Moves"].split(" ")
            if len(moves) != 2:
                continue
            # Parse the URL before appending anything so that a bad row cannot
            # leave the columns with different lengths.
            game_hash, num_moves = parse_game_url(row["GameUrl"])
            puzzles["GameHash"].append(game_hash)
            puzzles["NumMoves"].append(num_moves)
            for key, value in row.items():
                if key == "Moves":
                    # The first move leads to the puzzle position; the second
                    # is the expected answer.
                    puzzles["LastMoveUCI"].append(moves[0])
                    puzzles["AnswerUCI"].append(moves[1])
                elif key == "FEN":
                    puzzles["PrevFEN"].append(value)
                elif key in list_columns:
                    puzzles[key].append(value.split(" "))
                elif key in int_columns:
                    puzzles[key].append(int(value))
                else:
                    puzzles[key].append(value)
    return puzzles


# Load from PGN file
def load_pgn_file(pgn_file_path):
    """Read a multi-game PGN file into a dict keyed by Lichess game hash.

    A game is expected to be a header section, one blank line, the movetext,
    and a terminating blank line. Only games containing the
    ``[Variant "Standard"]`` header are kept.

    Args:
        pgn_file_path: Path to the PGN file.

    Returns:
        A dict mapping the game hash taken from the ``[Site "..."]`` header to
        the PGN text of that game (lines joined with ``\\n``). If a hash occurs
        more than once, the last game wins.

    Raises:
        AttributeError: If a kept game has no matching Lichess ``Site`` header.
    """
    # Count the lines in pure Python instead of shelling out to `wc -l`, which
    # does not exist on every platform. The count only sizes the progress bar.
    with open(pgn_file_path, "rb") as raw_file:
        num_lines = sum(1 for _ in raw_file)

    pgns = []  # Finished PGN strings
    pgn = []  # Lines of the game currently being collected
    empty_line_seen = False  # True once the header/movetext separator was seen
    with open(pgn_file_path, "r") as file:
        for line in tqdm(file, desc="Reading PGN file", total=num_lines):
            if line.strip():
                pgn.append(line.rstrip())
            elif not pgn:
                # Blank line before any content of the next game (leading
                # blanks or an extra blank between games). Counting it as the
                # header/movetext separator would split the following game
                # into a header-only block and a movetext-only block.
                continue
            elif not empty_line_seen:
                # First blank line inside a game: separates headers from moves.
                pgn.append("")
                empty_line_seen = True
            else:
                # Second blank line: the game is complete.
                pgns.append("\n".join(pgn))
                pgn = []
                empty_line_seen = False
    if pgn:  # File ended without a terminating blank line
        pgns.append("\n".join(pgn))

    # Key every standard-variant game by the hash in its Site header.
    site_pattern = re.compile(r'\[Site "https://lichess.org/([a-zA-Z0-9]+)"\]')
    return {
        site_pattern.search(game_text).group(1): game_text
        for game_text in pgns
        if '[Variant "Standard"]' in game_text
    }


# Columns produced by parse_pgn, in output order.
_PARSED_PGN_COLUMNS = (
    "PuzzlePlayer",
    "CompSimpPGN",
    "CurAnnoPGN",
    "CurSimpPGN",
    "CurFEN",
    "CurUCIs",
    "CurSANs",
    "AnswerSAN",
)


def _derive_puzzle_fields(
    complete_annotated_pgn,
    previous_fen,
    last_move_uci,
    answer_uci,
    num_moves,
):
    """Replay a game up to the puzzle position; raise on any inconsistency."""
    # After an even number of plies it is White's turn to move.
    puzzle_player = "White" if num_moves % 2 == 0 else "Black"
    complete_game = chess.pgn.read_game(io.StringIO(complete_annotated_pgn))
    complete_simplified_pgn = complete_game.accept(
        chess.pgn.StringExporter(headers=False, variations=False, comments=False)
    )

    mainline = list(complete_game.mainline())
    if not 1 <= num_moves <= len(mainline):
        raise ValueError(
            f"num_moves={num_moves} is outside the game's {len(mainline)} moves"
        )
    played = mainline[:num_moves]

    # Replay the first num_moves plies. Explicit raises are used instead of
    # `assert` so the checks still run under `python -O`.
    board = chess.Board()
    ucis = []
    sans = []
    for ply, node in enumerate(played, 1):
        # The position right before the last replayed move must be the FEN
        # recorded for the puzzle.
        if ply == num_moves and board.fen() != previous_fen:
            raise ValueError("replayed position does not match the puzzle FEN")
        ucis.append(node.move.uci())
        sans.append(node.san())
        board.push(node.move)
        if not board.is_valid():
            raise ValueError(f"invalid position after ply {ply}")
    if ucis[-1] != last_move_uci:
        raise ValueError("last replayed move does not match the puzzle's first move")
    current_fen = board.fen()
    current_ucis = ", ".join(ucis)
    current_sans = ", ".join(sans)

    # Rebuild a game holding only the replayed moves and their comments.
    game = chess.pgn.Game.from_board(chess.Board())
    game.add_line([node.move for node in played])
    for copied_node, source_node in zip(game.mainline(), played):
        copied_node.comment = source_node.comment
    current_annotated_pgn = game.accept(
        chess.pgn.StringExporter(headers=False, variations=False, comments=True)
    )
    # NOTE(review): the exporter appears to wrap long output over several
    # lines; deleting the newlines (rather than replacing them with spaces)
    # may glue the tokens on either side of a wrap. Kept as-is so existing
    # output does not change - confirm whether `columns=None` was intended.
    current_annotated_pgn = current_annotated_pgn.replace("\n", "")
    if not current_annotated_pgn.endswith("*"):
        raise ValueError("exported PGN does not end with the '*' result marker")
    current_annotated_pgn = current_annotated_pgn[:-1] + " *"
    current_simplified_pgn = game.accept(
        chess.pgn.StringExporter(headers=False, variations=False, comments=False)
    )

    # parse_uci rejects illegal moves; the SAN has to be computed before the
    # move is pushed because it depends on the position the move is played in.
    answer_move = board.parse_uci(answer_uci)
    answer_san = board.san(answer_move)
    board.push(answer_move)
    if not board.is_valid():
        raise ValueError("invalid position after the answer move")

    return {
        "PuzzlePlayer": [puzzle_player],
        "CompSimpPGN": [complete_simplified_pgn],
        "CurAnnoPGN": [current_annotated_pgn],
        "CurSimpPGN": [current_simplified_pgn],
        "CurFEN": [current_fen],
        "CurUCIs": [current_ucis],
        "CurSANs": [current_sans],
        "AnswerSAN": [answer_san],
    }


def parse_pgn(
    complete_annotated_pgn,
    previous_fen,
    last_move_uci,
    answer_uci,
    num_moves,
):
    """Derive the puzzle fields for one puzzle from its complete game PGN.

    The game is replayed for ``num_moves`` plies and cross-checked against the
    puzzle data: the position before the last replayed move must equal
    ``previous_fen`` and that move must equal ``last_move_uci``.

    Args:
        complete_annotated_pgn: PGN text of the complete game.
        previous_fen: FEN of the position before the last replayed move.
        last_move_uci: UCI of the move that leads to the puzzle position.
        answer_uci: UCI of the expected answer move.
        num_moves: Number of plies to replay, including ``last_move_uci``.

    Returns:
        A dict with the keys ``PuzzlePlayer``, ``CompSimpPGN``, ``CurAnnoPGN``,
        ``CurSimpPGN``, ``CurFEN``, ``CurUCIs``, ``CurSANs`` and ``AnswerSAN``.
        Each value is a one-element list on success and an empty list when the
        puzzle could not be processed. No exception is propagated.
    """
    try:
        return _derive_puzzle_fields(
            complete_annotated_pgn,
            previous_fen,
            last_move_uci,
            answer_uci,
            num_moves,
        )
    except Exception:
        # Deliberately best-effort: any unparsable or inconsistent puzzle
        # yields an empty result instead of aborting the whole run.
        return {column: [] for column in _PARSED_PGN_COLUMNS}


def main(pgn_file_path, puzzle_csv_path, dataset_save_path, chunk_index, chunk_size):
    """Build one chunk of the puzzle dataset and save it to disk.

    Nothing is done if the chunk directory already exists. Otherwise the PGN
    file and the puzzle CSV are loaded, each puzzle is joined with the PGN of
    its game, rows ``chunk_index * chunk_size`` up to (but excluding)
    ``(chunk_index + 1) * chunk_size`` are selected, every selected puzzle is
    run through ``parse_pgn``, and the result is saved under
    ``<dataset_save_path>/chunk_<chunk_index>``.

    Args:
        pgn_file_path: Path to the PGN file holding the puzzles' games.
        puzzle_csv_path: Path to the puzzle CSV file.
        dataset_save_path: Directory in which chunk directories are created.
        chunk_index: Zero-based index of the chunk to build.
        chunk_size: Number of puzzles per chunk.

    Raises:
        KeyError: If a puzzle refers to a game that ``load_pgn_file`` did not
            return.
    """
    chunk_path = os.path.join(dataset_save_path, f"chunk_{chunk_index}")
    if os.path.exists(chunk_path):
        print(f"Chunk {chunk_index} already exists, skipping.")
        return

    pgns = load_pgn_file(pgn_file_path)
    puzzles = load_puzzle_csv(puzzle_csv_path)
    puzzles["CompAnnoPGN"] = [pgns[game_hash] for game_hash in puzzles["GameHash"]]
    ds = Dataset.from_pandas(pd.DataFrame(puzzles))
    start = chunk_index * chunk_size
    stop = min(start + chunk_size, len(ds))
    ds = ds.select(range(start, stop))
    input_columns = list(ds.column_names)

    def parse_single_puzzle(batch):
        """Parse the single puzzle of a size-one batch, dropping it on failure."""
        parsed = parse_pgn(
            batch["CompAnnoPGN"][0],
            batch["PrevFEN"][0],
            batch["LastMoveUCI"][0],
            batch["AnswerUCI"][0],
            batch["NumMoves"][0],
        )
        if any(len(values) == 0 for values in parsed.values()):
            # The returned columns are merged into the input batch, so a
            # failed puzzle must blank the input columns as well. Otherwise
            # the batch would mix columns of length one and zero instead of
            # being dropped.
            return {**{column: [] for column in input_columns}, **parsed}
        return parsed

    ds = ds.map(
        parse_single_puzzle,
        num_proc=4,
        batched=True,
        batch_size=1,  # one puzzle per call, so a failure drops just that row
        desc="Parsing PGNs",
    )
    ds.save_to_disk(chunk_path)


if __name__ == "__main__":
    # Command-line entry point: build a single dataset chunk.
    argparser = argparse.ArgumentParser(
        description=(
            "Build one chunk of the puzzle dataset from a PGN file and the "
            "puzzle CSV."
        )
    )
    argparser.add_argument(
        "--pgn_file_path",
        type=str,
        default="./data/Lichess/one_move_puzzle.pgn",
        help="Path to the PGN file",
    )
    argparser.add_argument(
        "--puzzle_csv_path",
        type=str,
        default="./data/Lichess/lichess_db_puzzle.csv",
        help="Path to the CSV file containing the puzzles",
    )
    argparser.add_argument(
        "--dataset_save_path",
        type=str,
        default="./data/Lichess/chunks/",
        help="Directory in which the dataset chunks are saved",
    )
    # There is no sensible default chunk. Without a value the run used to fail
    # with a TypeError only after both input files had been loaded, so the
    # argument is required up front.
    argparser.add_argument(
        "--chunk_index",
        type=int,
        required=True,
        help="Chunk index to process",
    )
    argparser.add_argument(
        "--chunk_size",
        type=int,
        default=20000,
        help="Chunk size to process",
    )
    args = argparser.parse_args()
    main(
        args.pgn_file_path,
        args.puzzle_csv_path,
        args.dataset_save_path,
        args.chunk_index,
        args.chunk_size,
    )
