from dataclasses import dataclass
from datetime import datetime
import math
import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide"
import pygame
import sys
import time

pygame.init()

# --- Window size (pixels) ------------------------------------------------------
width, height = 600, 900

# --- Calculator state ------------------------------------------------------------
# The entry shown on the main display is held as two digit strings: the whole
# part and the fractional part. num_float records whether a decimal point has
# been entered. The old_* copies hold the previously seen display values.
clear_display = "0"
display_num_whole = old_display_num_whole = clear_display
display_num_decimal = old_display_num_decimal = ""
num_float = False
# Left operand (and whether it is fractional), captured when an operator key
# is pressed, together with that pending operator.
buffer_num = ""
buffer_num_float = False
operator = old_operator = ""
# Reset to 1 by C/AC; not read anywhere visible in this file yet.
float_precision = 1

# --- Colours (RGB) ---------------------------------------------------------------
yellow, white, black = (240, 240, 20), (255, 255, 255), (0, 0, 0)
background = (80, 80, 80)

# --- Button grid geometry (pixels) ---------------------------------------------
button_origin_x, button_origin_y = 90, 280
button_spacing = 140
button_radius = 50
button_background = (60, 60, 60)

@dataclass
class Button:
    """One round calculator key drawn on the on-screen button grid.

    The ``key`` code doubles as the identifier for mouse hits: clicking a
    button is dispatched exactly as if that key had been pressed.
    """

    # Grid cell; converted to pixels with button_origin_* and button_spacing.
    row: int
    column: int
    # Text rendered in the centre of the circle.
    label: str
    # pygame key constant this button stands for.
    key: int
    # True for the digit buttons 0-9.
    is_number: bool = False
    # Whether the white highlight ring is currently drawn.
    flashing: bool = False
    # time.time() value after which the highlight should be removed.
    unflash_on: float = 0.0
    # Pixel centre, filled in when the button is drawn; used for click tests.
    center_x: int = 0
    center_y: int = 0

# Alternate key codes (as strings, matching the ``buttons`` dict keys) mapped
# to the main-keyboard key whose on-screen button they should trigger: the
# numeric-keypad keys plus Return.
keypad_map = {
    str(pygame.K_KP_0): pygame.K_0,
    str(pygame.K_KP_1): pygame.K_1,
    str(pygame.K_KP_2): pygame.K_2,
    str(pygame.K_KP_3): pygame.K_3,
    str(pygame.K_KP_4): pygame.K_4,
    str(pygame.K_KP_5): pygame.K_5,
    str(pygame.K_KP_6): pygame.K_6,
    str(pygame.K_KP_7): pygame.K_7,
    str(pygame.K_KP_8): pygame.K_8,
    str(pygame.K_KP_9): pygame.K_9,
    str(pygame.K_KP_MINUS): pygame.K_MINUS,
    # Keypad "*" belongs to the asterisk button (was wrongly mapped to "-").
    str(pygame.K_KP_MULTIPLY): pygame.K_ASTERISK,
    str(pygame.K_KP_DIVIDE): pygame.K_SLASH,
    str(pygame.K_KP_PLUS): pygame.K_PLUS,
    str(pygame.K_KP_PERIOD): pygame.K_PERIOD,
    # Both Enter keys act as "=".
    str(pygame.K_KP_ENTER): pygame.K_EQUALS,
    str(pygame.K_RETURN): pygame.K_EQUALS
}
# Digit key codes accepted as number input: the main-row K_0..K_9 followed by
# the numeric-keypad K_KP_0..K_KP_9, in that order.
valid_keys_numbers = [
    getattr(pygame, f"K_{prefix}{digit}")
    for prefix in ("", "KP_")
    for digit in range(10)
]

# Key codes that trigger "=" (when Shift is not held).
valid_keys_equals = [pygame.K_EQUALS, pygame.K_RETURN, pygame.K_KP_ENTER]

# Every on-screen button, keyed by str(pygame key code). Each dict key is
# derived from the button's own ``key`` so the two cannot drift apart;
# insertion order follows the listing below.
buttons = {
    str(btn.key): btn
    for btn in (
        # Digits: 7-8-9 / 4-5-6 / 1-2-3 on rows 0-2, "0" in the middle of row 3.
        Button(3, 1, "0", pygame.K_0, True),
        Button(2, 0, "1", pygame.K_1, True),
        Button(2, 1, "2", pygame.K_2, True),
        Button(2, 2, "3", pygame.K_3, True),
        Button(1, 0, "4", pygame.K_4, True),
        Button(1, 1, "5", pygame.K_5, True),
        Button(1, 2, "6", pygame.K_6, True),
        Button(0, 0, "7", pygame.K_7, True),
        Button(0, 1, "8", pygame.K_8, True),
        Button(0, 2, "9", pygame.K_9, True),
        # Entry helpers on row 3.
        Button(3, 2, ".", pygame.K_PERIOD),
        Button(3, 0, "+-", pygame.K_QUESTION),
        # Editing / control column.
        Button(0, 3, "<-", pygame.K_BACKSPACE),
        Button(1, 3, "AC", pygame.K_a),
        Button(2, 3, "C", pygame.K_c),
        Button(3, 3, "=", pygame.K_EQUALS),
        # Operator row.
        Button(4, 3, "*", pygame.K_ASTERISK),
        Button(4, 0, "+", pygame.K_PLUS),
        Button(4, 1, "-", pygame.K_MINUS),
        Button(4, 2, chr(247), pygame.K_SLASH),
    )
}

def flash_button(b):
    """Toggle the white highlight ring around button *b*.

    When the button is not flashing, a 4px white ring is drawn, the button is
    marked as flashing and an expiry 0.1 s from now is recorded. When it is
    already flashing, the ring is painted over in the button background colour
    and the normal 2px yellow outline is redrawn.
    """
    centre = (
        button_origin_x + b.column * button_spacing,
        button_origin_y + b.row * button_spacing,
    )
    if b.flashing:
        # Erase the highlight, then restore the regular outline.
        pygame.draw.circle(screen, button_background, centre, button_radius, 4)
        pygame.draw.circle(screen, yellow, centre, button_radius, 2)
        b.flashing = False
        return
    pygame.draw.circle(screen, white, centre, button_radius, 4)
    b.flashing = True
    b.unflash_on = time.time() + 0.1

def draw_buttons():
    """Paint every calculator button and record its on-screen centre.

    Each button is a filled circle with a 2px yellow rim and its label
    centred in white. The pixel centre is stored on the Button so later
    mouse clicks can be hit-tested against it.
    """
    for btn in buttons.values():
        btn.center_x = button_origin_x + btn.column * button_spacing
        btn.center_y = button_origin_y + btn.row * button_spacing
        pos = (btn.center_x, btn.center_y)
        pygame.draw.circle(screen, button_background, pos, button_radius)
        pygame.draw.circle(screen, yellow, pos, button_radius, 2)
        label_surface = cheesy_font.render(btn.label, True, white)
        # "*" is shifted down 15px, presumably because the glyph sits high in
        # this font's cell.
        nudge = 15 if btn.label == "*" else 0
        top_left = (
            btn.center_x - label_surface.get_width() // 2,
            btn.center_y - label_surface.get_height() // 2 + nudge,
        )
        screen.blit(label_surface, top_left)

# Asset folders (img/, font/) are resolved next to this script so the program
# also starts when launched from a different working directory.
asset_dir = os.path.dirname(os.path.abspath(__file__))

screen = pygame.display.set_mode((width, height))
pygame.display.set_caption('CheesyCalc')
# Surface.convert_alpha() returns a NEW surface rather than converting in
# place, so the returned surface must be kept. It needs an active display
# mode, which set_mode() above provides.
icon = pygame.image.load(os.path.join(asset_dir, "img", "cheesy_icon_32x32.png")).convert_alpha()
pygame.display.set_icon(icon)
screen.fill(background)
pygame.display.flip()
logo = pygame.image.load(os.path.join(asset_dir, "img", "cheesy.png")).convert_alpha()
screen.blit(logo, (30, 20))
cheesy_font = pygame.font.Font(os.path.join(asset_dir, "font", "New Cheese.ttf"), 60)
logo_text_1 = cheesy_font.render("Cheesy", True, yellow)
logo_text_2 = cheesy_font.render("Calculator", True, yellow)
# Title lines are horizontally centred, then shifted 80px right (presumably
# to clear the logo drawn at x=30).
screen.blit(logo_text_1, (width // 2 - logo_text_1.get_width() // 2 + 80, 0))
screen.blit(logo_text_2, (width // 2 - logo_text_2.get_width() // 2 + 80, 45))
sevenseg_font = pygame.font.Font(os.path.join(asset_dir, "font", "DSEG7Modern-Bold.ttf"), 60)
sevenseg_font_sm = pygame.font.Font(os.path.join(asset_dir, "font", "Seven Segment.ttf"), 22)
draw_buttons()
pygame.display.flip()

def do_math():
    """Combine the buffered operand with the displayed one using ``operator``.

    The left operand is ``buffer_num``, parsed as float when
    ``buffer_num_float`` is set and as int otherwise. The right operand is the
    number on the display, rebuilt from ``display_num_whole`` plus, when
    ``num_float`` is set, "." and ``display_num_decimal``.

    Supports "+", "-", "*" and "/".

    Returns:
        ``str`` of the int or float result. Division by zero returns
        ``clear_display`` instead of raising, because the display has no
        error state. For any other operator the displayed number is returned
        unchanged, so callers never receive ``None``.
    """
    right_text = display_num_whole + ("." + display_num_decimal if num_float else "")
    if operator not in ("+", "-", "*", "/"):
        return right_text
    left = float(buffer_num) if buffer_num_float else int(buffer_num)
    right = float(right_text) if num_float else int(right_text)
    if operator == "+":
        return str(left + right)
    if operator == "-":
        return str(left - right)
    if operator == "*":
        return str(left * right)
    if right == 0:
        return clear_display
    return str(left / right)

# Frame limiter: without it the render loop below redraws as fast as the CPU
# allows, keeping a core busy even while the calculator is idle.
clock = pygame.time.Clock()

while True:
    for event in pygame.event.get():
        if event.type == pygame.QUIT:
            pygame.quit()
            sys.exit()
        # Translate a mouse click on a button into that button's key code so
        # clicks and key presses share the dispatch below.
        mouse_click = None
        if event.type == pygame.MOUSEBUTTONDOWN:
            mouse_x, mouse_y = pygame.mouse.get_pos()
            for b in buttons.values():
                if math.hypot(mouse_x - b.center_x, mouse_y - b.center_y) <= button_radius:
                    mouse_click = b.key
        if event.type == pygame.KEYDOWN or mouse_click is not None:
            my_key = event.key if mouse_click is None else mouse_click
            if my_key == pygame.K_ESCAPE:
                pygame.quit()
                sys.exit()
            elif my_key == pygame.K_c:
                # C: clear the current entry only; the buffered operand stays.
                display_num_whole = clear_display
                display_num_decimal = ""
                num_float = False
                float_precision = 1
                flash_button(buttons[str(my_key)])
            elif my_key == pygame.K_a:
                # AC: clear the entry and the pending operation.
                display_num_whole = clear_display
                display_num_decimal = ""
                num_float = False
                float_precision = 1
                buffer_num = ""
                buffer_num_float = False
                operator = ""
                old_operator = ""
                flash_button(buttons[str(my_key)])
            elif my_key == pygame.K_PERIOD or my_key == pygame.K_KP_PERIOD:
                actual_key = pygame.K_PERIOD
                flash_button(buttons[str(actual_key)])
                num_float = True
            elif my_key in valid_keys_equals and not pygame.key.get_mods() & pygame.KMOD_SHIFT:
                # TODO: "=" only flashes for now; it does not evaluate yet.
                actual_key = pygame.K_EQUALS
                flash_button(buttons[str(actual_key)])
            elif my_key == pygame.K_BACKSPACE:
                flash_button(buttons[str(my_key)])
                if num_float:
                    if display_num_decimal == "":
                        # No fractional digits left: drop the decimal point
                        # and delete from the whole part instead.
                        num_float = False
                        if len(display_num_whole) > 1:
                            display_num_whole = display_num_whole[:-1]
                        elif display_num_whole != "0":
                            display_num_whole = "0"
                    else:
                        display_num_decimal = display_num_decimal[:-1]
                else:
                    if len(display_num_whole) > 1:
                        display_num_whole = display_num_whole[:-1]
                    elif display_num_whole != "0":
                        display_num_whole = "0"
            elif (
                (my_key == pygame.K_EQUALS and pygame.key.get_mods() & pygame.KMOD_SHIFT)
                # K_PLUS is what a mouse click on the "+" button produces.
                or my_key in (pygame.K_PLUS, pygame.K_KP_PLUS)
            ):
                actual_key = pygame.K_PLUS
                if buffer_num == "":
                    # First operator press: move the entry into the buffer.
                    buffer_num = display_num_whole
                    operator = "+"
                    if num_float:
                        buffer_num_float = True
                        buffer_num += "." + display_num_decimal
                    num_float = False
                    display_num_whole = clear_display
                    display_num_decimal = ""
                else:
                    result = do_math()
                    if "." not in result:
                        display_num_whole = result
                        # Drop any stale fractional digits from the old entry.
                        display_num_decimal = ""
                        num_float = False
                    else:
                        display_num_whole, display_num_decimal = result.split(".")
                        if display_num_decimal == "0":
                            display_num_decimal = ""
                            num_float = False
                        else:
                            # Keep the entry marked fractional so later input
                            # and do_math() see the decimal part.
                            num_float = True
                flash_button(buttons[str(actual_key)])
            elif my_key == pygame.K_MINUS or my_key == pygame.K_KP_MINUS:
                actual_key = pygame.K_MINUS
                flash_button(buttons[str(actual_key)])
            elif (
                (my_key == pygame.K_8 and pygame.key.get_mods() & pygame.KMOD_SHIFT)
                # K_ASTERISK is what a mouse click on the "*" button produces.
                or my_key in (pygame.K_ASTERISK, pygame.K_KP_MULTIPLY)
            ):
                actual_key = pygame.K_ASTERISK
                flash_button(buttons[str(actual_key)])
            elif my_key == pygame.K_SLASH or my_key == pygame.K_KP_DIVIDE:
                actual_key = pygame.K_SLASH
                flash_button(buttons[str(actual_key)])
            elif my_key in valid_keys_numbers:
                # Keypad digits are routed to the matching main-row button.
                actual_key = keypad_map.get(str(my_key), my_key)
                flash_button(buttons[str(actual_key)])
                if len(display_num_whole) <= 10:
                    if not num_float:
                        # A lone "0" is replaced rather than extended, and
                        # typing 0 onto it is ignored (no leading zeros).
                        if display_num_whole == clear_display and actual_key != pygame.K_0:
                            display_num_whole = chr(actual_key)
                        elif display_num_whole != clear_display:
                            display_num_whole += chr(actual_key)
                    else:
                        display_num_decimal += chr(actual_key)
            # TODO: the "+-" button (K_QUESTION) is not handled yet.
    if display_num_whole != old_display_num_whole or display_num_decimal != old_display_num_decimal:
        # Remember the last value seen; nothing else reads these copies yet.
        old_display_num_whole = display_num_whole
        old_display_num_decimal = display_num_decimal
    # Display panel: black box with a yellow border, the entry right-aligned,
    # and the buffered operand plus operator in small text above it.
    pygame.draw.rect(screen, black, pygame.Rect(10, 100, width - 20, 110))
    pygame.draw.rect(screen, yellow, pygame.Rect(10, 100, width - 20, 110), 2)
    main_text = sevenseg_font.render(display_num_whole + "." + display_num_decimal, True, white)
    screen.blit(main_text, (width - 20 - main_text.get_width(), 130))
    buffer_text = sevenseg_font_sm.render(buffer_num + " " + operator, True, white)
    screen.blit(buffer_text, (width - 20 - buffer_text.get_width(), 130 - 4 - buffer_text.get_height()))
    # Restore any button whose highlight period has expired.
    now = time.time()
    for b in buttons.values():
        if b.flashing and now >= b.unflash_on:
            flash_button(b)
    pygame.display.flip()
    clock.tick(60)