from time import time
import re

import bpy
from mathutils import Vector

from mathutils.geometry import (barycentric_transform,
                                intersect_line_plane)

from ..utils import (plane_on_bone,
                    ray_cast_point,
                    obj_ray_cast,
                    triangle_normal,
                    search_square,
                    get_gp_draw_plane,
                    create_plane,
                    following_keys)


class GP_OT_pick_collection(bpy.types.Operator):
    # Guess the mesh collection that goes with the active Grease Pencil object
    # from the object's name, and store it in the scene interpolation settings.
    # NOTE: documented with comments instead of a class docstring so the
    # operator tooltip is left untouched (no bl_description is set here).
    bl_idname = "gp.pick_collection"
    bl_label = "Auto Pick Mesh Collection"
    #bl_description = 'Pick Mesh Collection'
    bl_options = {'REGISTER', 'UNDO'}

    @classmethod
    def poll(cls, context):
        '''Available only when the active object is a Grease Pencil object.'''
        return context.active_object and context.object.type == 'GPENCIL'

    def execute(self, context):
        '''Set the target collection to the "<prefix>_mesh" collection that
        matches the longest prefix of the object name.

        The object name is split on "-" and "_" (separators are kept), and
        every prefix, from the full name down to the empty string, is tried
        as "<prefix>_mesh" against the names of all collections of the scene.

        Returns {"FINISHED"} when a collection was found, {"CANCELLED"} otherwise.
        '''
        parts = re.split(r'(-|_)', context.object.name)

        ## Index scene collections by name (first occurrence wins on duplicates)
        cols_by_name = {}
        for col in context.scene.collection.children_recursive:
            cols_by_name.setdefault(col.name, col)

        ## Longest prefix first, so the most specific collection is picked
        ## whatever the order of collections in the scene hierarchy
        for i in range(len(parts), -1, -1):
            base = ''.join(parts[:i])
            tgt_name = f'{base}_mesh'
            col = cols_by_name.get(tgt_name)
            if col is not None:
                print(f'Automatic set "{tgt_name}"')
                context.scene.gp_interpo_settings.target_collection = col
                return {"FINISHED"}

        self.report({"WARNING"}, 'No Collection found')
        return {"CANCELLED"}


class GP_OT_interpolate_stroke_base(bpy.types.Operator):
    '''Shared modal machinery for stroke interpolation operators.

    Handles state setup (invoke), frame stepping (modal), temporary setting
    changes (apply_and_store / restore) and cleanup (exit).
    Subclasses implement interpolate_frame() and call super().invoke() first.
    '''
    bl_idname = "gp.interpolate_stroke_base"
    bl_label = "Interpolate Stroke"
    bl_description = 'Interpolate Stroke based on user bound triangle'
    bl_options = {'REGISTER', 'UNDO'}

    interactive : bpy.props.BoolProperty(name='Interactive', default=False, options={'SKIP_SAVE'})

    next : bpy.props.BoolProperty(name='Next', default=True, options={'SKIP_SAVE'})

    @classmethod
    def poll(cls, context):
        '''Need an active Grease Pencil object in edit, sculpt or paint mode.'''
        if context.active_object and context.object.type == 'GPENCIL'\
            and context.mode in ('EDIT_GPENCIL', 'SCULPT_GPENCIL', 'PAINT_GPENCIL'):
            return True
        cls.poll_message_set("Need a Grease pencil object in Edit or Sculpt mode")
        return False

    @classmethod
    def description(cls, context, properties):
        '''Tooltip depending on the operator properties.'''
        if properties.interactive:
            ## Adjacent literals: no source indentation leaks into the tooltip
            return ("Interactive interpolate mode"
                    "\nUse Left <- -> Right keys"
                    "\n+Ctrl to jump over key interpolated during modal")
        if properties.next:
            return "Interpolate Stroke Forward"
        else:
            return "Interpolate Stroke Backward"

    def apply_and_store(self, attrs):
        '''Remember current values of attributes and optionally override them.

        individual item in attrs: (prop, attr, [new_val])
        The previous value is always stored (for restore); new_val is set only
        when provided.
        '''
        for item in attrs:
            prop, attr = item[:2]
            self.stored_attrs.append( (prop, attr, getattr(prop, attr)) )
            if len(item) >= 3:
                setattr(prop, attr, item[2])

    def restore(self):
        '''Put back every value stored by apply_and_store.'''
        for prop, attr, old_val in self.stored_attrs:
            setattr(prop, attr, old_val)

    def exit(self, context, status='INFO', text=None, cancelled=False):
        '''Clean up modal state and report.

        Clears header text, ends progress, removes the timer, restores stored
        settings and (outside debug mode, BONE method) removes the helper
        plane and its collection.

        status: report type used with `text`
        text: message to report; when None a summary is reported as INFO
        cancelled: prefix the summary with "(Stopped!)"

        Return {'FINISHED'} when status is 'INFO', else {'CANCELLED'}
        '''
        context.area.header_text_set(None)
        wm = context.window_manager
        if self.report_progress:
            wm.progress_end() # Pgs
        if self.timer:
            wm.event_timer_remove(self.timer)
            self.timer = None
        self.restore()
        if self.debug:
            if self.scan_time is not None:
                print(f"Paste'n'place time {time()-self.start - self.scan_time}s")
        else:
            if self.settings.method == 'BONE':
                ## Remove Plane and it's collection after use
                if self.plane is not None:
                    bpy.data.objects.remove(self.plane)
                if self.tool_col is not None:
                    bpy.data.collections.remove(self.tool_col)

        ## loop_count is reset on every event in interactive mode,
        ## count interpolated keys instead (minus the starting frame)
        if self.interactive:
            done = len(self.interpolated_keys) - 1
        else:
            done = self.loop_count

        cancel_state = '(Stopped!) ' if cancelled else ''
        mess = f'{cancel_state}{done} interpolated frame(s) ({time()-self.start:.3f}s)'

        if text:
            print(mess)
            self.report({status}, text)
        else:
            self.report({'INFO'}, mess)

        if status == 'INFO':
            return {'FINISHED'}
        return {'CANCELLED'}

    ## Added to operators own invoke with super().invoke(context, event)
    def invoke(self, context, event):
        '''Initialise operator state and list the frames to process.

        Return a status set ({'CANCELLED'}) when the operator cannot run,
        None otherwise (the subclass invoke then continues its own setup).
        '''
        self.debug = False
        self.stored_attrs = [] # context manager store/restore
        self.loop_count = 0 # frames list iterator
        self.start = time()
        self.scan_time = None # to print time at exit in debug mode
        self.plane = None # 3D Plane for bone interpolation
        self.tool_col = None # collection containing 3D plane
        self.gp = context.object
        self.settings = context.scene.gp_interpo_settings
        self.frames_to_jump = []
        self.cancelled = False
        self.timer = None
        self.timer_event = 'TIMER'
        self.report_progress = (self.settings.use_animation and not self.interactive)
        self.interpolated_keys = {context.scene.frame_current}

        ## Remove interpolation_plane collection ! (unseen, but can be hit)
        if interp_plane := bpy.data.objects.get('interpolation_plane'):
            bpy.data.objects.remove(interp_plane)
        if interp_col := bpy.data.collections.get('interpolation_tool'):
            bpy.data.collections.remove(interp_col)

        if context.mode != 'EDIT_GPENCIL':
            self.report({"ERROR"}, "Mode need to be Edit Grease Pencil")
            return {"CANCELLED"}

        ## Visible, unlocked layers having at least one selected stroke on their active frame
        self.layers = [l for l in self.gp.data.layers
                  if (not l.lock and l.active_frame and not l.hide)
                  and next((s for s in l.active_frame.strokes if s.select), None)]

        self.strokes = [s for l in self.layers for s in l.active_frame.strokes if s.select]
        if not self.strokes:
            self.report({"ERROR"}, "No strokes selected")
            return {"CANCELLED"}

        if self.interactive:
            ## Keys in both directions + current frame, sorted to navigate with arrows
            self.frames_to_jump = following_keys(forward=True, animation=True)
            self.frames_to_jump += following_keys(forward=False, animation=True)
            self.frames_to_jump.append(context.scene.frame_current)
            self.frames_to_jump.sort()
            context.area.header_text_set('Frame interpolation < jump with left-right arrow keys > | Esc/Enter: Stop') # (+Ctrl to skip all already interpolated)
        else:
            ## Determine on what key/keys to jump
            self.frames_to_jump = following_keys(forward=self.next, animation=self.settings.use_animation)
            if not self.frames_to_jump:
                return self.exit(context, status='WARNING', text='No keyframe available in this direction')

            # TODO: Expose timer (in preferences ?) to let user more time to see result between frames
            self.timer = context.window_manager.event_timer_add(0.04, window=context.window)

        if self.report_progress:
            context.window_manager.progress_begin(self.frames_to_jump[0], self.frames_to_jump[-1]) # Pgs

    def modal(self, context, event):
        '''Step through frames: one frame per timer tick, or per left/right
        arrow key press in interactive mode. Esc / Enter / right click stops.'''
        scn = context.scene

        if event.type in {'RIGHTMOUSE', 'ESC', 'RET'}:
            return self.exit(context, status='WARNING', text='Cancelling', cancelled=True)

        frame = None
        if self.interactive:
            current_frame = scn.frame_current
            self.loop_count = 0 # Reset to keep infinite loop
            if event.value == 'PRESS' and event.type in ('LEFT_ARROW', 'RIGHT_ARROW'):
                if event.type == 'LEFT_ARROW':
                    candidates = [f for f in reversed(self.frames_to_jump) if f < current_frame]
                else:
                    candidates = [f for f in self.frames_to_jump if f > current_frame]
                if event.ctrl:
                    ## +Ctrl: jump over keys already interpolated during this modal
                    candidates = [f for f in candidates if f not in self.interpolated_keys]

                frame = candidates[0] if candidates else None
                if frame is None:
                    self.report({'WARNING'}, 'No frame to jump to in this direction!')

            ## Only an arrow key jump triggers a step here. Timer events are
            ## ignored: there is no frame to jump to (frame_set(None) would raise)
            if frame is None:
                return {'RUNNING_MODAL'}

        else:
            frame_num = len(self.frames_to_jump)
            percentage = (self.loop_count) / (frame_num) * 100
            context.area.header_text_set(f'Interpolation {percentage:.0f}% {self.loop_count + 1}/{frame_num} | Esc: Cancel')

            if event.type != self.timer_event:
                return {'RUNNING_MODAL'}
            frame = self.frames_to_jump[self.loop_count]

        scn.frame_set(frame)
        if frame in self.interpolated_keys:
            self.report({'INFO'}, f'SKIP {frame} (already interpolated)')
            if self.interactive:
                return {'RUNNING_MODAL'}
            ## Timer mode: fall through to advance, else the same frame
            ## would be retried on every tick
        else:
            print(f'-> {frame}')
            if self.report_progress:
                context.window_manager.progress_update(frame) # Pgs

            ## Interpolate function
            self.interpolate_frame(context)

            if self.interactive:
                self.interpolated_keys.add(frame)
                return {'RUNNING_MODAL'}

        ## Timer mode: next frame in list, stop after the last one
        self.loop_count += 1
        if self.loop_count >= len(self.frames_to_jump):
            return self.exit(context)

        return {'RUNNING_MODAL'}

    def interpolate_frame(self, context):
        '''Interpolate strokes on the current frame. To override in subclass.'''
        raise NotImplementedError('Not Implemented')


## Converted to modal from "operator_single"

class GP_OT_interpolate_stroke(GP_OT_interpolate_stroke_base):
    '''Interpolate selected strokes by following the surface found under each point.

    At invoke, every point of the selected strokes is ray cast from the scene
    camera to find a triangle (scene geometry, a target object, or a helper
    plane placed on a bone). On each target frame the strokes are pasted and
    every point is moved with barycentric_transform from the stored triangle
    to the same triangle on the evaluated object, then projected on the
    Grease Pencil draw plane.
    '''
    bl_idname = "gp.interpolate_stroke"
    bl_label = "Interpolate Stroke"
    bl_description = 'Interpolate Stroke'
    bl_options = {'REGISTER', 'UNDO'}

    def iterative_search(self, context, obj, coord, origin, dg):
        '''Search geometry for outside point (where raycast did not hit any geometry)

        Ray cast around coord using search_square, with a factor growing from
        1 to 9 times settings.search_range, and stop at first hit.

        obj: object to ray cast on, scene ray cast if None
        coord: world coordinate of the point
        origin: ray origin (camera location)
        dg: depsgraph

        return :
            object_hit, hit_location, tri, tri_indices.
            hit_location is on the line origin -> coord, on the plane facing
            the camera that passes through the hit.
            (None, None, None, None) if nothing found
        '''

        for iteration in range(1, 10):
            for square_co in search_square(coord, factor=self.settings.search_range * iteration):

                if obj:
                    object_hit, hit_location, tri, tri_indices = obj_ray_cast(obj, square_co, origin, dg)
                else:
                    object_hit, hit_location, tri, tri_indices = ray_cast_point(square_co, origin, dg)

                if object_hit:
                    ## On view plane (alternative: plane of the hit triangle, using triangle_normal(*tri))
                    view_vec = context.scene.camera.matrix_world.to_quaternion() @ Vector((0,0,1))
                    hit_location = intersect_line_plane(origin, coord, hit_location, view_vec)

                    return object_hit, hit_location, tri, tri_indices

        return None, None, None, None

    def _place_plane_on_bone(self):
        '''Call plane_on_bone with target rig / bone from settings (BONE method).'''
        plane_on_bone(self.settings.target_rig.pose.bones.get(self.settings.target_bone),
                      arm=self.settings.target_rig,
                      set_rotation=self.settings.use_bone_rotation,
                      mesh=True)

    def invoke(self, context, event):
        '''Scan the surface under every point of the selected strokes, copy
        the strokes and start the modal.

        Fill self.strokes_data: one list per stroke of
        (stroke, point_co_world, object_hit, hit_location, tri, tri_indices)
        '''
        if state := super().invoke(context, event):
            return state

        scn = context.scene

        ## Rays are cast from the camera. Leave through exit() so timer,
        ## progress and header text set by the base invoke are cleaned
        if scn.camera is None:
            return self.exit(context, status='ERROR', text='No active camera in scene')

        origin = scn.camera.matrix_world.to_translation()

        ## Gathered (and checked not empty) by the base invoke
        strokes = self.strokes

        ## NOTE(review): settings.target_collection is not used to filter ray casts here
        target_obj = None

        ## Setup depending on method
        if self.settings.method == 'BONE':
            if not self.settings.target_rig or not self.settings.target_bone:
                return self.exit(context, status='ERROR', text='No Bone selected')

            ## Ensure collection and plane exists
            # get/create collection
            self.tool_col = bpy.data.collections.get('interpolation_tool')
            if not self.tool_col:
                self.tool_col = bpy.data.collections.new('interpolation_tool')

            if self.tool_col.name not in bpy.context.scene.collection.children:
                bpy.context.scene.collection.children.link(self.tool_col)
                self.tool_col.hide_viewport = True # needed ?

            # get/create meshplane
            self.plane = bpy.data.objects.get('interpolation_plane')
            if not self.plane:
                self.plane = create_plane(name='interpolation_plane')
                self.plane.select_set(False)

            if self.plane.name not in self.tool_col.objects:
                self.tool_col.objects.link(self.plane)
            target_obj = self.plane

        elif self.settings.method == 'OBJECT':
            if not self.settings.target_object:
                return self.exit(context, status='ERROR', text='No Object selected')

            target_obj = self.settings.target_object
            if target_obj.library:
                ## Look if an override exists in scene to use instead of default object
                if (override := next((o for o in scn.objects if o.name == target_obj.name and o.override_library), None)):
                    target_obj = override

        ## Prepare context manager
        attrs = [
            (context.tool_settings, 'use_keyframe_insert_auto', True),
            ]

        ## Set everything in SETUP list (restored in exit)
        self.apply_and_store(attrs)

        if self.settings.method == 'BONE':
            ## replace plane
            self._place_plane_on_bone()

            ## Set collection visibility
            intercol = bpy.data.collections.get('interpolation_tool')
            vl_col = bpy.context.view_layer.layer_collection.children.get(intercol.name)
            intercol.hide_viewport = vl_col.exclude = vl_col.hide_viewport = False

        dg = bpy.context.evaluated_depsgraph_get()
        self.strokes_data = []

        for stroke_index, stroke in enumerate(strokes):
            stroke_data = []
            for point_index, point in enumerate(stroke.points):
                point_co_world = self.gp.matrix_world @ point.co

                if target_obj:
                    ## Object raycast
                    object_hit, hit_location, tri, tri_indices = obj_ray_cast(target_obj, point_co_world, origin, dg)
                else:
                    ## Scene raycast
                    object_hit, hit_location, tri, tri_indices = ray_cast_point(point_co_world, origin, dg)

                ## Iterative increasing search range when no surface hit
                if not object_hit:
                    object_hit, hit_location, tri, tri_indices = self.iterative_search(context, target_obj, point_co_world, origin, dg)

                    if not object_hit:
                        ## /!\ ERROR ! No surface found!
                        # For debugging, select only point.
                        bpy.ops.gpencil.select_all(action='DESELECT')
                        point.select = True
                        return self.exit(context, status='ERROR', text=f'Stroke {stroke_index} point {point_index} could not find underlying geometry')

                stroke_data.append((stroke, point_co_world, object_hit, hit_location, tri, tri_indices))

            self.strokes_data.append(stroke_data)

        if self.debug:
            self.scan_time = time()-self.start
            print(f'Scan time {self.scan_time:.4f}s')

        # Copy stroke selection
        bpy.ops.gpencil.select_linked() # Ensure whole stroke are selected before copy
        bpy.ops.gpencil.copy()

        context.window_manager.modal_handler_add(self)
        return {'RUNNING_MODAL'}

    def interpolate_frame(self, context):
        '''Paste copied strokes on current frame and move their points
        according to the surface scanned at invoke.'''
        scn = context.scene
        origin = scn.camera.matrix_world.to_translation()
        plane_co, plane_no = get_gp_draw_plane(self.gp)
        bpy.ops.gpencil.select_all(action='DESELECT')
        bpy.ops.gpencil.paste(type='LAYER')

        if self.settings.method == 'BONE':
            ## Set plane on the bone
            self._place_plane_on_bone()

        dg = bpy.context.evaluated_depsgraph_get()

        ## Get pasted stroke (selected after paste)
        new_strokes = [s for l in self.layers for s in l.active_frame.strokes if s.select]
        ## Keep reference to all accessible other strokes (in all accessible layer)
        other_strokes = [s for l in self.gp.data.layers if l.active_frame and not l.lock for s in l.active_frame.strokes if not s.select]

        smooth_level = self.settings.smooth_level
        ## Same matrix for every point, invert once
        world_to_local = self.gp.matrix_world.inverted()

        occluded_points = []
        for new_stroke, stroke_data in zip(new_strokes, self.strokes_data):
            world_co_3d = []
            for _stroke, _point_co, object_hit, hit_location, tri_a, tri_indices in stroke_data:
                eval_ob = object_hit.evaluated_get(dg)
                ## Same triangle (vertex indices), at its location on current frame
                tri_b = [eval_ob.matrix_world @ eval_ob.data.vertices[i].co for i in tri_indices]

                new_loc = barycentric_transform(hit_location, *tri_a, *tri_b)
                world_co_3d.append(new_loc)

            # Smooth points (neighbor average need at least two points)
            if smooth_level and len(world_co_3d) > 1:
                old_co_3d = [data[1] for data in stroke_data]
                points_velocity = [b-a for a, b in zip(old_co_3d, world_co_3d)]
                last = len(points_velocity) - 1

                # Average of points displacement with neighbors, smooth_level + 1 passes
                for _ in range(smooth_level + 1):
                    points_velocity = [
                        (points_velocity[j] + points_velocity[j + 1]) / 2 if j == 0 else
                        (points_velocity[j] + points_velocity[j - 1]) / 2 if j == last else
                        (points_velocity[j - 1] + points_velocity[j] + points_velocity[j + 1]) / 3
                        for j in range(len(points_velocity))
                    ]

                world_co_3d = [a+b for a, b in zip(old_co_3d, points_velocity)]

            ## Reproject on plane (along line camera -> point)
            new_local_co_3d = []
            for p in world_co_3d:
                projected = intersect_line_plane(origin, p, plane_co, plane_no)
                if projected is None:
                    ## No intersection (line parallel to plane): keep 3D location
                    projected = p
                new_local_co_3d.extend(world_to_local @ projected)

            new_stroke.points.foreach_set('co', new_local_co_3d)
            new_stroke.points.update()

            ## Occlusion management
            if self.settings.method == 'GEOMETRY' and self.settings.remove_occluded:
                for point, point_co in zip(new_stroke.points, world_co_3d):
                    vec_direction = point_co - origin
                    ## Raycast with slightly reduced distance (avoid occlusion on initial surface)
                    n_hit = scn.ray_cast(dg, origin, vec_direction, distance=vec_direction.length - 0.001)[0]
                    if n_hit:
                        occluded_points.append(point)

        if occluded_points:
            ## Select only occluded point
            bpy.ops.gpencil.select_all(action='DESELECT')
            for point in occluded_points:
                point.select = True
            ## remove points
            bpy.ops.gpencil.delete(type='POINTS')

            ## restore selection (keep new strokes selected)
            bpy.ops.gpencil.select_all(action='SELECT')
            for stroke in other_strokes:
                stroke.select = False


## Operator classes of this module, in registration order
## (unregistered in reverse order)
classes = (
    GP_OT_interpolate_stroke_base,
    GP_OT_interpolate_stroke,
    GP_OT_pick_collection,
)

def register():
    '''Register every class listed in `classes`, in order.'''
    register_class = bpy.utils.register_class
    for operator_cls in classes:
        register_class(operator_cls)
    
        
def unregister():
    '''Unregister every class listed in `classes`, last registered first.'''
    unregister_class = bpy.utils.unregister_class
    for operator_cls in classes[::-1]:
        unregister_class(operator_cls)