import bpy
import numpy as np
from time import perf_counter, time, sleep
from mathutils import Vector, Matrix

from .. import utils

from mathutils.geometry import (barycentric_transform,
                                intersect_point_tri,
                                intersect_point_line,
                                intersect_line_plane,
                                tessellate_polygon)

import math
import gpu
from bpy_extras import view3d_utils
from gpu_extras.batch import batch_for_shader


def raycast_objects(context, event, dg):
    """Cast a ray from the mouse position and return the closest mesh hit.

    The ray starts at the view origin under the mouse cursor and is tested
    against every mesh object / mesh instance of the evaluated depsgraph.

    Args:
        context: Blender context providing the 3D view ``region`` and ``region_data``.
        event: event providing ``mouse_region_x`` / ``mouse_region_y``.
        dg: depsgraph passed by the caller. Currently NOT used: the evaluated
            depsgraph is re-fetched from ``context`` (kept for callers).

    Returns:
        tuple ``(object, hit_location, normal, face_index)``:
        ``object`` is the *original* (non-evaluated) object hit,
        ``hit_location`` is in world space, ``normal`` is left in the hit
        object's local space and ``face_index`` is the value reported by
        ``Object.ray_cast``. ``(None, None, None, None)`` when nothing is hit.
    """

    region = context.region
    rv3d = context.region_data
    coord = event.mouse_region_x, event.mouse_region_y

    view_vector = view3d_utils.region_2d_to_vector_3d(region, rv3d, coord)
    ray_origin = view3d_utils.region_2d_to_origin_3d(region, rv3d, coord)
    ray_target = ray_origin + view_vector

    def visible_objects_and_duplis():
        """Yield ``(object, world_matrix)`` pairs for every depsgraph object instance."""
        # NOTE: the passed ``dg`` is not used (as in the original implementation);
        # the evaluated depsgraph is fetched from the context instead.
        depsgraph = context.evaluated_depsgraph_get()
        for dup in depsgraph.object_instances:
            if dup.is_instance:  # Real dupli instance: use the instance matrix
                yield (dup.instance_object, dup.matrix_world.copy())
            else:  # Usual object
                obj = dup.object
                yield (obj, obj.matrix_world.copy())

    def obj_ray_cast(obj, matrix):
        """Ray cast ``obj`` after moving the ray into its local space.

        Returns ``(location, normal, face_index)`` in object space, or
        ``(None, None, None)`` on a miss or when ``matrix`` is not invertible.
        """
        try:
            matrix_inv = matrix.inverted()
        except ValueError:
            # Degenerate transform (e.g. object scaled to 0): it cannot be hit,
            # skip it instead of aborting the whole raycast.
            return None, None, None

        ray_origin_obj = matrix_inv @ ray_origin
        ray_target_obj = matrix_inv @ ray_target
        ray_direction_obj = ray_target_obj - ray_origin_obj

        success, location, normal, face_index = obj.ray_cast(ray_origin_obj, ray_direction_obj)

        if success:
            return location, normal, face_index
        return None, None, None

    # Cast rays and keep the hit closest to the ray origin
    best_length_squared = -1.0
    best_obj = None
    face_normal = None
    hit_loc = None
    obj_face_index = None
    for obj, matrix in visible_objects_and_duplis():
        if obj.type != 'MESH':
            continue
        hit, normal, face_index = obj_ray_cast(obj, matrix)
        if hit is None:
            continue
        hit_world = matrix @ hit
        length_squared = (hit_world - ray_origin).length_squared
        if best_obj is None or length_squared < best_length_squared:
            best_length_squared = length_squared
            best_obj = obj
            face_normal = normal
            hit_loc = hit_world
            obj_face_index = face_index

    if best_obj is not None:
        # Return the original datablock so callers can store / look it up by name
        return best_obj.original, hit_loc, face_normal, obj_face_index

    return None, None, None, None


# -----------------
### Drawing
# -----------------

def circle_2d(coord, r, num_segments):
    '''Return ``num_segments`` points evenly spaced on a 2D circle.

    Args:
        coord: ``(x, y)`` center of the circle.
        r: radius.
        num_segments: number of points to generate (must be > 0).

    Returns:
        list of ``(x, y)`` tuples going counter-clockwise, starting at angle 0
        (``(cx + r, cy)``). The first point is not repeated at the end.

    ref: http://slabode.exofire.net/circle_draw.shtml
    '''
    cx, cy = coord
    theta = 2 * math.pi / num_segments
    # Compute every angle directly (instead of rotating the previous point)
    # so no floating-point drift accumulates and the circle closes exactly.
    return [(cx + r * math.cos(i * theta), cy + r * math.sin(i * theta))
            for i in range(num_segments)]

def draw_callback_px(self, context):
    """Draw the bind-points overlay in the 3D view (registered as 'POST_VIEW').

    Draws, only in the area where the operator was started:
      - a small yellow circle on each coordinate of ``self.current_points`` (hover hint),
      - the clicked path through ``self.point_list`` (closed once it has 3+ points),
      - a red circle, facing the view, on every clicked point.

    Args:
        self: the running operator (reads ``current_area``, ``current_points``,
            ``point_list``, ``circle_pts`` and ``mini_circle_pts``).
        context: draw context.
    """
    if context.area != self.current_area:
        return

    shader = gpu.shader.from_builtin('UNIFORM_COLOR')
    view_rot = context.region_data.view_rotation

    def draw_strip(points, color):
        """Draw ``points`` as a connected line strip of a uniform ``color``."""
        batch = batch_for_shader(shader, 'LINE_STRIP', {"pos": points})
        shader.bind()
        shader.uniform_float("color", color)
        batch.draw(shader)

    def view_circle(center, template):
        """Rotate the circle ``template`` to face the view, move it to ``center`` and close it."""
        pts = [(view_rot @ cp) + center for cp in template]
        pts.append(pts[0])  # Loop with last point
        return pts

    # Alpha blending, 2 pixel width lines
    gpu.state.blend_set('ALPHA')
    gpu.state.line_width_set(2.0)
    try:
        # Hover hint on the vertex a click would pick
        for coord in self.current_points:
            draw_strip(view_circle(coord, self.mini_circle_pts), (0.8, 0.8, 0.0, 0.9))

        # Draw clicked path
        if not self.point_list:
            return

        color = (0.9, 0.1, 0.2, 0.8)
        line_color = color
        positions = [pt['co'] for pt in self.point_list]
        if len(positions) >= 3:
            ## Duplicate first position at last to loop triangle
            positions.append(positions[0])
            line_color = (0.9, 0.1, 0.3, 1.0)
        draw_strip(positions, line_color)

        # Circle on each clicked point, aligned with view
        for pt in self.point_list:
            draw_strip(view_circle(pt['co'], self.circle_pts), color)
    finally:
        # Restore GPU defaults on every exit path, including the early return
        # above (previously skipped, leaking blend/line width state).
        gpu.state.line_width_set(1.0)
        gpu.state.blend_set('NONE')


# -----------------
### Operator
# -----------------

class GP_OT_bind_points(bpy.types.Operator):
    """Pick 3 mesh vertices in the viewport to use as reference for interpolation.

    The binding is stored on the window manager in the ID property
    ``tri_<grease pencil object name>`` as ``{'0': pt, '1': pt, '2': pt}``,
    each ``pt`` being a dict with ``object`` (mesh object name), ``index``
    (vertex index) and ``co`` (world coordinate when picked).
    """
    bl_idname = "gp.bind_points"
    bl_label = "Bind Points"
    bl_description = 'Bind points to use as reference for interpolation'
    bl_options = {'REGISTER', 'UNDO'}

    @classmethod
    def poll(cls, context):
        return context.object and context.object.type == 'GPENCIL'

    # When enabled, only delete the stored binding of the active object
    clear : bpy.props.BoolProperty(name='Clear', default=False, options={'SKIP_SAVE'})

    def invoke(self, context, event):
        self.debug = False
        self.gp = context.object
        self.settings = context.scene.gp_interpo_settings
        self.point_list = []  # clicked points (dicts built by get_closest_vert)
        self.current_points = []  # hovered vertex coordinate, drawn as a hint
        wm = context.window_manager
        prop_name = f'tri_{self.gp.name}'
        if tri_dump := wm.get(prop_name):
            if self.clear:
                del wm[prop_name]
                return {'FINISHED'}

            ## Load stored binding: ID property groups -> plain dicts
            self.point_list = [dict(tri_dump[str(i)]) for i in range(3)]

            ## Update world coordinate position at current frame
            dg = context.evaluated_depsgraph_get()
            if not self._refresh_coordinates(context, dg):
                # A bound object was renamed/deleted or its mesh changed:
                # start from scratch instead of crashing.
                self.report({'WARNING'}, 'Stored bind points are no longer valid, pick new points')
                self.point_list = []

        if self.clear:
            return {'FINISHED'}

        ## Prepare circle 3D coordinates (drawn aligned with the view)
        self.circle_pts = [Vector((p[0], p[1], 0)) for p in circle_2d((0, 0), 0.01, 12)]
        self.mini_circle_pts = [Vector((p[0], p[1], 0)) for p in circle_2d((0, 0), 0.005, 12)]

        # Draw in view space, only in the area where the operator was started
        self.current_area = context.area
        args = (self, context)
        self._handle = bpy.types.SpaceView3D.draw_handler_add(draw_callback_px, args, 'WINDOW', 'POST_VIEW')
        context.area.header_text_set('Bind points | Enter: Valid | Backspace: remove last point | Esc / Right-Click: Cancel')
        wm.modal_handler_add(self)
        return {'RUNNING_MODAL'}

    def _refresh_coordinates(self, context, dg):
        """Recompute the world coordinate ``co`` of each point in ``self.point_list``.

        Returns False (list possibly partially updated) when a referenced
        object is missing, is not a mesh, or no longer has the stored vertex index.
        """
        for point in self.point_list:
            ob = context.scene.objects.get(point['object'])
            if ob is None or ob.type != 'MESH':
                return False
            ob_eval = ob.evaluated_get(dg)
            vertices = ob_eval.data.vertices
            if not 0 <= point['index'] < len(vertices):
                return False
            point['co'] = ob_eval.matrix_world @ vertices[point['index']].co
        return True

    def exit_modal(self, context, status='INFO', text=None):
        """Remove the draw handler, reset the header and report ``text``.

        ``status`` is the report type used with ``text`` (e.g. 'INFO', 'ERROR');
        'Done' is reported as INFO when no text is given.
        """
        bpy.types.SpaceView3D.draw_handler_remove(self._handle, 'WINDOW')
        context.area.header_text_set(None)
        context.area.tag_redraw()
        if text:
            self.report({status}, text)
        else:
            ## report standard info
            self.report({'INFO'}, 'Done')

    def get_closest_vert(self, object_hit, hit_location, _normal, face_index, dg):
        """Return the vertex of the hit face closest to ``hit_location``.

        Returns:
            dict with 'object' (object name), 'index' (vertex index in the
            evaluated mesh) and 'co' (world coordinate), or None when
            ``face_index`` does not address a polygon of the evaluated mesh
            (e.g. -1, which would otherwise silently wrap to the last polygon).
        """
        ob_eval = object_hit.evaluated_get(dg)
        mesh = ob_eval.data
        if face_index is None or not 0 <= face_index < len(mesh.polygons):
            return None

        # NOTE(review): for a hit on an instance, matrix_world is the source
        # object's, not the instance's — the closest vertex may be wrong there.
        mat = ob_eval.matrix_world
        closest = min(
            mesh.polygons[face_index].vertices,
            key=lambda vert_idx: ((mat @ mesh.vertices[vert_idx].co) - hit_location).length,
        )
        return {
            'object': object_hit.name,  # store object 'name'
            'index': closest,  # store vertex index
            'co': mat @ mesh.vertices[closest].co,  # store initial absolute coordinate
        }

    def bind(self, context):
        ## Store points on window-manager ID props, under the key read by invoke
        context.window_manager[f'tri_{self.gp.name}'] = {str(i): d for i, d in enumerate(self.point_list)}
        self.exit_modal(context, text='Bound!')

    def modal(self, context, event):
        context.area.tag_redraw()
        if event.type in ('RIGHTMOUSE', 'ESC'):
            self.exit_modal(context, text='Cancelled')
            return {'CANCELLED'}

        if event.type in ('WHEELUPMOUSE', 'WHEELDOWNMOUSE', 'MIDDLEMOUSE'):
            return {'PASS_THROUGH'}

        if event.type == 'MOUSEMOVE':
            ## Hover hint: show the vertex a click would pick
            ## (raycast on every move; disable if too intensive)
            dg = context.evaluated_depsgraph_get()
            object_hit, hit_location, _normal, face_index = raycast_objects(context, event, dg)
            pt = None
            if object_hit is not None:
                pt = self.get_closest_vert(object_hit, hit_location, _normal, face_index, dg)
            self.current_points = [pt['co']] if pt else []
            return {'PASS_THROUGH'}

        elif event.type in ('BACK_SPACE', 'DEL') and event.value == 'PRESS':
            if self.point_list:
                self.report({'INFO'}, 'Removed last point')
                self.point_list.pop()

        elif event.type in ('RET', 'SPACE') and event.value == 'PRESS':
            ## Valid
            if len(self.point_list) < 3:
                self.exit_modal(context, status='ERROR', text='Not enough point selected, Cancelling')
                return {'CANCELLED'}
            self.bind(context)
            return {'FINISHED'}

        elif event.type == 'LEFTMOUSE' and event.value == 'PRESS':
            if len(self.point_list) >= 3:
                ## A click after the third point validates
                self.bind(context)
                return {'FINISHED'}

            ## Raycast objects of the evaluated depsgraph and store the closest vertex
            dg = context.evaluated_depsgraph_get()
            object_hit, hit_location, _normal, face_index = raycast_objects(context, event, dg)
            vert = None
            if object_hit is not None:
                vert = self.get_closest_vert(object_hit, hit_location, _normal, face_index, dg)

            if vert is None:
                self.report({'WARNING'}, 'Nothing hit, Retry on a surface')
            elif any(vert['index'] == x['index'] and vert['object'] == x['object'] for x in self.point_list):
                self.report({'WARNING'}, "Cannot use same point twice !")
            else:
                self.point_list.append(vert)
                self.report({'INFO'}, f"Set point {len(self.point_list)}")

        return {'RUNNING_MODAL'}

    def execute(self, context):
        return {"FINISHED"}

# Classes (un)registered with Blender by register() / unregister()
classes = (GP_OT_bind_points,)

def register():
    """Register every class of ``classes`` with Blender, in declaration order."""
    for operator_cls in classes:
        bpy.utils.register_class(operator_cls)
    
        
def unregister():
    """Unregister every class of ``classes``, last declared first."""
    for operator_cls in classes[::-1]:
        bpy.utils.unregister_class(operator_cls)