import json
from playwright.sync_api import sync_playwright
from pathlib import Path
from browser_env.processors import ObservationHandler, TextObervationProcessor
import math
import random
import re



def overlap(box1, box2):
    """Return True when two boxes share a region of positive area.

    Each box is an ``(x, y, width, height)`` sequence.  Boxes that only touch
    along an edge (or where one has zero width/height at the contact line)
    are reported as NOT overlapping, because every comparison is ``<=``.
    """
    ax0, ay0, aw, ah = box1
    bx0, by0, bw, bh = box2

    # Far edges (right / bottom) of each box.
    ax1, ay1 = ax0 + aw, ay0 + ah
    bx1, by1 = bx0 + bw, by0 + bh

    # The boxes are disjoint if one lies entirely to one side of the other.
    disjoint = ax1 <= bx0 or bx1 <= ax0 or ay1 <= by0 or by1 <= ay0
    return not disjoint
    
# Maps an action verb to the action name it is translated to.  Every verb
# maps to itself except 'press', which becomes 'key_press'.  Not referenced
# elsewhere in this script.
actn_dict = dict(
    click='click',
    hover='hover',
    type='type',
    press='key_press',
    goto='goto',
    go_forward='go_forward',
    new_tab='new_tab',
    close_tab='close_tab',
    switch_tab='switch_tab',
)
    
# Script body: for every record in INPUT_PATH, render record['html'] in a
# Chromium page, build the accessibility-tree text observation, locate the
# element whose HTML id is TARGET_ELEMENT_ID, map it to the closest
# overlapping observation node, rewrite the actions to reference that node,
# and append the annotated record to OUTPUT_PATH.  Records that cannot be
# resolved or that raise are skipped and reported on stdout.

INPUT_PATH = 'wikihow_0501.jsonl'
OUTPUT_PATH = 'parsed_wikihow_0501.jsonl'
TMP_HTML_PATH = Path('tmp.html')
TARGET_ELEMENT_ID = 'next-action-target-element'

# Text argument of a `click_and_type(<target>, <text>)` call: everything after
# the first comma up to the next ')' on the same line.
_TYPE_ARG_RE = re.compile(r',\s*(.*?)\)')
_NON_TEXT_RE = re.compile(r'[^a-zA-Z ]')


def _extract_type_content(action_text):
    """Return the sanitized typed text of a ``click_and_type(...)`` call.

    Only ASCII letters and spaces are kept; surrounding whitespace is stripped.

    Raises:
        ValueError: if no ``, ...)`` argument can be found in *action_text*.
    """
    match = _TYPE_ARG_RE.search(action_text)
    if match is None:
        raise ValueError(f'cannot parse typed text from {action_text!r}')
    return _NON_TEXT_RE.sub('', match.group(1)).strip()


def _anonymize_prev_actions(prev_actions):
    """Rewrite previous-action lines to reference random placeholder element ids.

    A fresh random id in [1, 10000] is drawn at the start and after every
    '#' comment line (presumably each comment introduces a new step).
    click_and_type / click / hover calls are rewritten to use that id; other
    lines are kept.  Lines of length <= 1 and lines containing '# example'
    are dropped.  ``type`` calls use the same ``string=`` keyword form as the
    rewritten next action so both fields share one grammar.
    """
    lines = prev_actions.split('\n')
    placeholder = random.randint(1, 10000)
    for i, action_line in enumerate(lines):
        if action_line.strip().startswith('#'):
            placeholder = random.randint(1, 10000)
            continue
        # 'click_and_type' must be tested before 'click' (substring match).
        if 'click_and_type' in action_line:
            text = _extract_type_content(action_line)
            lines[i] = f'type(element_id="{placeholder}",string="{text}")'
        elif 'click' in action_line:
            lines[i] = f'click(element_id="{placeholder}")'
        elif 'hover' in action_line:
            lines[i] = f'hover(element_id="{placeholder}")'
    kept = [x for x in lines if len(x) > 1 and '# example' not in x]
    return '\n'.join(kept)


def _rewrite_next_action(next_action, axt_nodeid):
    """Rewrite the next action so it targets accessibility node *axt_nodeid*.

    Assumes the action call is the last line of *next_action*: every earlier
    line (comments included) is kept verbatim and the call is replaced.
    Actions other than click_and_type / click / hover are left unchanged.
    Lines containing '# example' are dropped from the result.
    """
    all_lines = next_action.split('\n')
    # Executable part only: '#' comment lines removed before keyword matching.
    action = '\n'.join(x for x in all_lines if not x.strip().startswith('#'))
    header = '\n'.join(all_lines[:-1])
    if 'click_and_type' in action:
        text = _extract_type_content(action)
        action = f'{header}\ntype(element_id="{axt_nodeid}",string="{text}")'
    elif 'click' in action:
        action = f'{header}\nclick(element_id="{axt_nodeid}")'
    elif 'hover' in action:
        action = f'{header}\nhover(element_id="{axt_nodeid}")'
    else:
        action = next_action
    return '\n'.join(x for x in action.split('\n') if '# example' not in x)


def _find_target_backend_node_id(tree):
    """Return the backendNodeId of the node whose 'id' attribute is TARGET_ELEMENT_ID.

    *tree* is the snapshot held in ``text_processor.d_tree``; its layout
    (a shared ``strings`` table, per-node ``attributes`` lists of string-table
    indices) presumably follows CDP DOMSnapshot — confirm.  Returns None when
    no such node exists.
    """
    strings = tree['strings']
    if 'id' not in strings:
        return None
    id_name_idx = strings.index('id')
    nodes = tree['documents'][0]['nodes']
    for backend_id, attrs in zip(nodes['backendNodeId'], nodes['attributes']):
        # Walk (name, value) pairs so an attribute *value* equal to "id"
        # is never mistaken for the attribute name.
        for name_idx, value_idx in zip(attrs[0::2], attrs[1::2]):
            if name_idx == id_name_idx and strings[value_idx] == TARGET_ELEMENT_ID:
                return backend_id
    return None


def _closest_overlapping_node(obs_nodes_info, target_bound):
    """Return the id of the observation node nearest to *target_bound*.

    Only nodes whose ``union_bound`` overlaps *target_bound* are candidates;
    distance is measured between the top-left corners.  The first node at
    the minimum distance wins.  Returns None when nothing overlaps.
    """
    target_x, target_y = target_bound[0], target_bound[1]
    best_id, best_dist = None, float('inf')
    for node_id, node_info in obs_nodes_info.items():
        bound = node_info['union_bound']
        if not overlap(bound, target_bound):
            continue
        dist = math.hypot(target_x - bound[0], target_y - bound[1])
        if dist < best_dist:
            best_id, best_dist = node_id, dist
    return best_id


def _observe_target(page, cdp_client, html):
    """Load *html* into *page* and locate the target element in its AX tree.

    Returns ``(axt_text, axt_nodeid)`` or None when the target element, its
    bounding rect, or an overlapping observation node cannot be found.
    """
    TMP_HTML_PATH.write_text(html, encoding='utf-8')
    page.goto(TMP_HTML_PATH.resolve().as_uri())
    obs_handler = ObservationHandler(
        "text",
        "accessibility_tree",
        "",
        False,
        {"width": 1280, "height": 1080},
    )
    cdp_client.send("Accessibility.enable", {})
    obs = obs_handler.get_observation(page, cdp_client)
    obs_metadata = obs_handler.get_observation_metadata()

    backend_node_id = _find_target_backend_node_id(obs_handler.text_processor.d_tree)
    if backend_node_id is None:
        print("can't find backendnodeid")
        return None
    print("backend_node_id", backend_node_id)

    rect_response = TextObervationProcessor.get_bounding_client_rect(cdp_client, backend_node_id)
    if rect_response.get("result", {}).get("subtype", "") == "error":
        print("get_bounding_client_rect error")
        return None
    rect = rect_response["result"]["value"]
    target_bound = [rect["x"], rect["y"], rect["width"], rect["height"]]

    axt_nodeid = _closest_overlapping_node(obs_metadata["text"]["obs_nodes_info"], target_bound)
    if axt_nodeid is None:
        print("no overlap union bound")
        return None
    print(axt_nodeid)
    return obs['text'], axt_nodeid


def _locate_target(browser, html):
    """Run _observe_target in a fresh page, always detaching CDP and closing the page."""
    page = browser.new_page()
    try:
        cdp_client = page.context.new_cdp_session(page)
        try:
            return _observe_target(page, cdp_client, html)
        finally:
            cdp_client.detach()
    finally:
        page.close()


def _process_record(browser, data):
    """Annotate one record in place; return it, or None if its target is unresolved.

    Adds 'axt' (observation text) and 'axt_nodeid', and rewrites
    'prev_actions' and 'next_action'.  Parse errors propagate to the caller.
    """
    data['prev_actions'] = _anonymize_prev_actions(data['prev_actions'])
    located = _locate_target(browser, data['html'])
    if located is None:
        return None
    axt_text, axt_nodeid = located
    data['axt'] = axt_text
    data['axt_nodeid'] = axt_nodeid
    data['next_action'] = _rewrite_next_action(data['next_action'], axt_nodeid)
    return data


with sync_playwright() as p:
    browser = p.chromium.launch()
    try:
        with open(INPUT_PATH, 'r', encoding='utf-8') as in_file:
            with open(OUTPUT_PATH, 'a', encoding='utf-8') as out_file:
                for line_no, line in enumerate(in_file, start=1):
                    try:
                        record = _process_record(browser, json.loads(line))
                    except Exception as err:
                        # Best-effort batch job: report and move on, but no
                        # longer swallow KeyboardInterrupt/SystemExit.
                        print(f'skipping line {line_no}: {err!r}')
                        continue
                    if record is None:
                        print(f'skipping line {line_no}: target node not resolved')
                        continue
                    print('next_action', record['next_action'])
                    print('prev_actions', record['prev_actions'])
                    out_file.write(json.dumps(record) + "\n")
    finally:
        browser.close()
