import bpy
import numpy as np
from time import perf_counter, time
from mathutils import Vector, Matrix

from gp_interpolate.utils import (matrix_transform,
                                  plane_on_bone,
                                  ray_cast_point,
                                  obj_ray_cast,
                                  intersect_with_tesselated_plane,
                                  triangle_normal,
                                  search_square,
                                  get_gp_draw_plane,
                                  create_plane,
                                  following_keys,
                                  index_list_from_bools,
                                  attr_set)

from mathutils.geometry import (barycentric_transform,
                                intersect_point_tri,
                                intersect_point_line,
                                intersect_line_plane,
                                tessellate_polygon)


## Working non-modal operator:
## it cannot be cancelled once the animation is launched.
## Advantage: the "with" statement (attr_set) restores the stored state in case of error.

class GP_OT_interpolate_stroke(bpy.types.Operator):
    """Move selected Grease Pencil strokes along with the surface under them.

    Every point of the selected strokes is ray-cast from the scene camera
    position onto a surface (a plane placed on a bone, a chosen object, or the
    scene geometry, depending on ``settings.method``). The strokes are then
    copied, pasted on the next/previous key(s), and each point is moved with the
    barycentric transform of the triangle it hit before being reprojected on
    the Grease Pencil drawing plane.
    """
    bl_idname = "gp.interpolate_stroke"
    bl_label = "Interpolate Stroke"
    bl_description = 'Interpolate Stroke'
    bl_options = {'REGISTER', 'UNDO'}

    @classmethod
    def poll(cls, context):
        if context.active_object and context.object.type == 'GPENCIL'\
            and context.mode in ('EDIT_GPENCIL', 'SCULPT_GPENCIL', 'PAINT_GPENCIL'):
            return True
        cls.poll_message_set("Need a Grease pencil object in Edit, Sculpt or Paint mode")
        return False

    @classmethod
    def description(cls, context, properties):
        if properties.next:
            return "Interpolate Stroke Forward"
        return "Interpolate Stroke Backward"

    next : bpy.props.BoolProperty(name='Next', default=True, options={'SKIP_SAVE'})

    @staticmethod
    def _ray_cast(target_obj, co, origin, dg):
        """Run the project ray-cast helper for world-space point `co` seen from `origin`.

        Uses `obj_ray_cast` against `target_obj` when one is set, otherwise the
        scene-wide `ray_cast_point`.

        Returns:
            The helper's (object_hit, hit_location, tri, tri_indices) tuple.
        """
        if target_obj:
            return obj_ray_cast(target_obj, Vector(co), origin, dg)
        return ray_cast_point(co, origin, dg)

    @classmethod
    def _search_surface(cls, target_obj, point_co, origin, dg, search_range):
        """Retry the ray cast on candidate points around `point_co`.

        Candidates come from `search_square` with a factor of
        `search_range` * 1..5, and the search stops at the first hit. The
        returned location is where the original ray (origin -> point_co) meets
        the plane of the hit triangle. It falls back to the raw hit location when
        that ray does not intersect the plane.

        Returns:
            (object_hit, hit_location, tri, tri_indices), or None when nothing
            was hit.
        """
        for iteration in range(1, 6):
            for square_co in search_square(point_co, factor=search_range * iteration):
                object_hit, raw_location, tri, tri_indices = cls._ray_cast(target_obj, square_co, origin, dg)
                if not object_hit:
                    continue
                hit_location = intersect_line_plane(origin, point_co, tri[0], triangle_normal(*tri))
                if hit_location is None:
                    # Ray parallel to the triangle plane: keep the actual hit
                    hit_location = raw_location
                return object_hit, hit_location, tri, tri_indices
        return None

    @staticmethod
    def _remove_occluded(scn, dg, origin, layer, new_stroke, world_co_3d):
        """Keep only the parts of `new_stroke` that are visible from `origin`.

        A point counts as occluded when a scene ray cast from `origin` toward
        it hits something first. The distance is shortened by 0.001 so the
        source surface itself is not reported as a hit. If every point is
        visible, the stroke is left untouched. Otherwise, one new stroke is
        created in `layer.active_frame` per run of visible points (isolated
        single points are dropped), and `new_stroke` is removed.
        """
        viz_list = []
        for nco in world_co_3d:
            vec_direction = nco - origin
            ## Reduced distance slightly to avoid occlusion on same source...
            dist = vec_direction.length - 0.001
            n_hit = scn.ray_cast(dg, origin, vec_direction, distance=dist)[0]
            # if there is a hit, it's occluded
            viz_list.append(not n_hit)

        if all(viz_list):
            # All visible, do nothing (just keep previous stroke)
            return

        if any(viz_list):
            # Create sub-strokes according to indices in original stroke
            for sublist in index_list_from_bools(viz_list):
                ## Skip isolated points
                if len(sublist) == 1:
                    continue

                ns = layer.active_frame.strokes.new()
                for elem in ('hardness', 'material_index', 'line_width'):
                    setattr(ns, elem, getattr(new_stroke, elem))

                ns.points.add(len(sublist))
                for i, point_index in enumerate(sublist):
                    src_point = new_stroke.points[point_index]
                    for elem in ('uv_factor', 'uv_fill', 'uv_rotation', 'pressure', 'co', 'strength', 'vertex_color'):
                        setattr(ns.points[i], elem, getattr(src_point, elem))

        ## Delete original stroke
        layer.active_frame.strokes.remove(new_stroke)

    def execute(self, context):
        """Scan the selected strokes, then paste and re-place them on each target frame.

        Returns {'CANCELLED'} (with a report) in any of these cases:
        no key to jump to, no scene camera, no selected stroke, a missing
        bone/object for the chosen method, or a point whose underlying surface
        cannot be found. Returns {'FINISHED'} otherwise.
        """
        debug = False
        settings = context.scene.gp_interpo_settings
        scn = bpy.context.scene

        ## Determine on what key/keys to jump
        frames_to_jump = following_keys(forward=self.next, animation=settings.use_animation)
        if not len(frames_to_jump):
            self.report({'WARNING'}, 'No keyframe available in this direction')
            return {'CANCELLED'}

        ## Rays are cast from the camera position, so a scene camera is required
        if not scn.camera:
            self.report({'ERROR'}, 'No active camera in scene')
            return {'CANCELLED'}

        gp = context.object
        matrix = gp.matrix_world
        origin = scn.camera.matrix_world.to_translation()

        col = settings.target_collection
        if not col:
            col = scn.collection

        ## Layers to read selected strokes from (and to find pasted strokes in).
        ## NOTE(review): intended to match the layers gpencil.copy/paste work on
        ## (visible and unlocked); layers without a current frame are skipped.
        self.layers = [l for l in gp.data.layers
                       if not l.hide and not l.lock and l.active_frame]

        tgt_strokes = [s for l in self.layers for s in l.active_frame.strokes if s.select]

        ## If nothing selected in sculpt/paint, select all strokes of the active layer
        active_layer = gp.data.layers.active
        if (not tgt_strokes
                and context.mode in ('SCULPT_GPENCIL', 'PAINT_GPENCIL')
                and active_layer and active_layer.active_frame):
            tgt_strokes = list(active_layer.active_frame.strokes)
            for s in tgt_strokes:
                s.select = True

        if not tgt_strokes:
            self.report({'ERROR'}, 'No stroke selected !')
            return {'CANCELLED'}

        included_cols = [c.name for c in gp.users_collection]
        target_obj = None
        plane = toolcol = None  # Only set in BONE mode; removed after use
        start = time()

        if settings.method == 'BONE':
            if not settings.target_rig or not settings.target_bone:
                self.report({'ERROR'}, 'No Bone selected')
                return {'CANCELLED'}

            included_cols.append('interpolation_tool')

            ## Ensure the tool collection exists and is linked to the scene
            toolcol = bpy.data.collections.get('interpolation_tool')
            if not toolcol:
                toolcol = bpy.data.collections.new('interpolation_tool')

            if toolcol.name not in bpy.context.scene.collection.children:
                bpy.context.scene.collection.children.link(toolcol)
                toolcol.hide_viewport = True # needed ?

            ## Ensure the mesh plane exists and lives in the tool collection
            plane = bpy.data.objects.get('interpolation_plane')
            if not plane:
                plane = create_plane(name='interpolation_plane')
                plane.select_set(False)

            if plane.name not in toolcol.objects:
                toolcol.objects.link(plane)
            target_obj = plane

        elif settings.method == 'GEOMETRY':
            if col != context.scene.collection:
                included_cols.append(col.name)

        elif settings.method == 'OBJECT':
            if not settings.target_object:
                self.report({'ERROR'}, 'No Object selected')
                return {'CANCELLED'}
            target_obj = settings.target_object
            if target_obj.library:
                ## Look if an override exists in scene to use instead of default object
                if (override := next((o for o in scn.objects if o.name == target_obj.name and o.override_library), None)):
                    target_obj = override

        ## Attributes temporarily overridden while the operator runs (restored by attr_set)
        store_list = [
            (context.tool_settings, 'use_keyframe_insert_auto', True),
        ]

        # TODO: for now, the collection filter is not used at all in GEOMETRY mode
        # it can be used to hide collection for faster animation mode

        if settings.method == 'BONE':
            ## Hide (viewport only) top-level collections that are not in included_cols.
            ## Using 'exclude' instead would stop the rig from updating.
            for vlc in context.view_layer.layer_collection.children:
                store_list.append((vlc, 'hide_viewport', vlc.name not in included_cols))

        scan_time = 0.0
        try:
            with attr_set(store_list):
                if settings.method == 'BONE':
                    ## Place the plane on the bone
                    plane_on_bone(settings.target_rig.pose.bones.get(settings.target_bone),
                                  arm=settings.target_rig,
                                  set_rotation=settings.use_bone_rotation,
                                  mesh=True)

                    ## Set tool collection visibility
                    vl_col = bpy.context.view_layer.layer_collection.children.get(toolcol.name)
                    toolcol.hide_viewport = vl_col.exclude = vl_col.hide_viewport = False

                ## Scan: find the surface triangle under each point at the current frame
                dg = bpy.context.evaluated_depsgraph_get()
                strokes_data = []

                for si, stroke in enumerate(tgt_strokes):
                    nb_points = len(stroke.points)
                    local_co = np.empty(nb_points * 3, dtype='float64')
                    stroke.points.foreach_get('co', local_co)
                    world_co_3d = matrix_transform(local_co.reshape((nb_points, 3)), matrix)

                    stroke_data = []
                    for i, point_co_world in enumerate(world_co_3d):
                        hit = self._ray_cast(target_obj, point_co_world, origin, dg)
                        if not hit[0]:
                            ## Nothing right under the point: widen the search
                            hit = self._search_surface(target_obj, point_co_world, origin, dg,
                                                       settings.search_range)

                        if hit is None:
                            ## /!\ No surface found: select only the problematic
                            ## stroke and point so they can be inspected
                            for sid, s in enumerate(tgt_strokes):
                                is_bad_stroke = sid == si
                                s.select = is_bad_stroke
                                for ip, p in enumerate(s.points):
                                    p.select = is_bad_stroke and ip == i
                            self.report({'ERROR'}, f'Stroke {si} point {i} could not find underlying geometry')
                            return {'CANCELLED'}

                        # (stroke, point_co_world, object_hit, hit_location, tri, tri_indices)
                        stroke_data.append((stroke, point_co_world, *hit))

                    strokes_data.append(stroke_data)

                if debug:
                    scan_time = time() - start
                    print(f'Scan time {scan_time:.4f}s')

                # Copy stroke selection
                bpy.ops.gpencil.copy()

                # Jump frame and paste
                wm = bpy.context.window_manager
                # min/max so the range stays ordered when going backward
                wm.progress_begin(min(frames_to_jump), max(frames_to_jump))
                try:
                    for f in frames_to_jump:
                        wm.progress_update(f)
                        scn.frame_set(f)
                        origin = scn.camera.matrix_world.to_translation()
                        plan_co, plane_no = get_gp_draw_plane(gp)
                        bpy.ops.gpencil.paste(type="LAYER")

                        if settings.method == 'BONE':
                            ## Re-place the plane on the bone at this frame
                            plane_on_bone(settings.target_rig.pose.bones.get(settings.target_bone),
                                          arm=settings.target_rig,
                                          set_rotation=settings.use_bone_rotation,
                                          mesh=True)

                        dg = bpy.context.evaluated_depsgraph_get()
                        matrix_inv = np.array(gp.matrix_world.inverted(), dtype='float64')
                        ## Pasted strokes are assumed to be the selected ones on this frame
                        new_strokes = [(l, s) for l in self.layers if l.active_frame
                                       for s in l.active_frame.strokes if s.select]

                        ## Pair from the end with the scanned data
                        ## (presumes pasted strokes come last, in copy order — confirm)
                        for (layer, new_stroke), stroke_data in zip(reversed(new_strokes), reversed(strokes_data)):
                            world_co_3d = []
                            for _stroke, _point_co, object_hit, hit_location, tri_a, tri_indices in stroke_data:
                                ## Same triangle, evaluated at the new frame
                                eval_ob = object_hit.evaluated_get(dg)
                                tri_b = [eval_ob.data.vertices[vi].co for vi in tri_indices]
                                tri_b = matrix_transform(tri_b, eval_ob.matrix_world)
                                world_co_3d.append(barycentric_transform(hit_location, *tri_a, *tri_b))

                            ## Reproject on the GP drawing plane along the camera rays
                            new_world_co_3d = [intersect_line_plane(origin, p, plan_co, plane_no) for p in world_co_3d]
                            new_local_co_3d = matrix_transform(new_world_co_3d, matrix_inv)

                            nb_points = len(new_stroke.points)
                            new_stroke.points.foreach_set('co', new_local_co_3d.reshape(nb_points * 3))
                            new_stroke.points.update()

                            ## Occlusion management
                            if settings.method == 'GEOMETRY' and settings.remove_occluded:
                                self._remove_occluded(scn, dg, origin, layer, new_stroke, world_co_3d)
                finally:
                    wm.progress_end()
        finally:
            if plane is not None and not debug:
                ## Remove plane and its collection after use (skipped in debug mode)
                bpy.data.objects.remove(plane)
                bpy.data.collections.remove(toolcol)

        if debug:
            print(f"Paste'n'place time {time() - start - scan_time}s")

        if len(frames_to_jump) > 1:
            self.report({'INFO'}, f'{len(frames_to_jump)} interpolated frame(s) ({time()-start:.3f}s)')

        return {'FINISHED'}


classes = (
    GP_OT_interpolate_stroke,
)


def register():
    """Register every class listed in `classes` with Blender."""
    for cls in classes:
        bpy.utils.register_class(cls)


def unregister():
    """Unregister the classes listed in `classes`, last registered first."""
    for cls in classes[::-1]:
        bpy.utils.unregister_class(cls)