import bpy
import numpy as np
from time import perf_counter, time
from mathutils import Vector, Matrix

from gp_interpolate.utils import (matrix_transform,
                                  plane_on_bone,
                                  ray_cast_point,
                                  intersect_with_tesselated_plane,
                                  triangle_normal,
                                  search_square,
                                  get_gp_draw_plane,
                                  create_plane,
                                  following_keys,
                                  attr_set)


from mathutils.geometry import (barycentric_transform,
                                intersect_point_tri,
                                intersect_point_line,
                                intersect_line_plane,
                                tessellate_polygon)

## /!\ Old code kept for testing
## Uses pseudo-plane coordinates instead of raycasting on a real mesh plane


class GP_OT_interpolate_stroke_simple(bpy.types.Operator):
    # Copy the selected strokes to the next/previous keyframe and move their
    # points so they follow a target between the two frames.
    #
    # A line from the scene camera position through each stroke point is
    # intersected with a target: a plane built on a bone ('BONE' method) or the
    # surface of objects in the target collection (any other method). The
    # point's position on the hit triangle is re-applied to the same triangle
    # at the destination frame (barycentric transform), then projected back
    # onto the Grease Pencil drawing plane along the same camera line.
    bl_idname = "gp.interpolate_stroke_simple"
    bl_label = "Interpolate Stroke Simple"
    bl_description = 'Interpolate Stroke Simple'
    bl_options = {'REGISTER', 'UNDO'}

    @classmethod
    def poll(cls, context):
        # Available on a Grease Pencil object in Edit, Sculpt or Paint mode.
        ob = context.object
        if ob and ob.type == 'GPENCIL' \
                and context.mode in ('EDIT_GPENCIL', 'SCULPT_GPENCIL', 'PAINT_GPENCIL'):
            return True
        cls.poll_message_set("Need a Grease pencil object in Edit, Sculpt or Paint mode")
        return False

    @classmethod
    def description(cls, context, properties):
        # Tooltip reflects the interpolation direction.
        return "Interpolate Stroke Forward" if properties.next else "Interpolate Stroke Backward"

    next : bpy.props.BoolProperty(name='Next', default=True, options={'SKIP_SAVE'})

    def execute(self, context):
        tool_settings = context.tool_settings
        auto_key_status = tool_settings.use_keyframe_insert_auto
        # Autokey is forced on for the duration of the operation (presumably so
        # the paste at the destination frame creates a keyframe if none exists).
        tool_settings.use_keyframe_insert_auto = True
        try:
            return self._interpolate(context)
        finally:
            ## Always restore autokey status, including on cancel or error
            tool_settings.use_keyframe_insert_auto = auto_key_status

    def _interpolate(self, context):
        """Copy selected strokes to the following key and move them along the target.

        Returns the operator result set: {'FINISHED'} or {'CANCELLED'} (after
        reporting the reason).
        """
        settings = context.scene.gp_interpo_settings
        use_bone = settings.method == 'BONE'

        ## Determine on what key to jump
        frames_to_jump = following_keys(forward=self.next)
        if not frames_to_jump:
            self.report({'WARNING'}, 'No keyframe available in this direction')
            return {'CANCELLED'}
        target_frame = frames_to_jump[0]

        gp = context.object
        scn = context.scene

        if scn.camera is None:
            self.report({'ERROR'}, 'No active camera in scene')
            return {'CANCELLED'}

        layer = gp.data.layers.active
        if layer is None or layer.active_frame is None:
            self.report({'ERROR'}, 'No active layer or frame')
            return {'CANCELLED'}
        frame = layer.active_frame

        dg = context.evaluated_depsgraph_get()
        matrix = np.array(gp.matrix_world, dtype='float64')
        col = settings.target_collection or scn.collection
        # Lines are traced from the camera position through each stroke point
        origin = np.array(scn.camera.matrix_world.to_translation(), 'float64')

        tgt_strokes = [s for s in frame.strokes if s.select]

        ## If nothing selected in sculpt/paint, select all before triggering
        if not tgt_strokes and context.mode in ('SCULPT_GPENCIL', 'PAINT_GPENCIL'):
            for s in frame.strokes:
                s.select = True
            tgt_strokes = list(frame.strokes)

        if not tgt_strokes:
            self.report({'ERROR'}, 'No stroke selected !')
            return {'CANCELLED'}

        # One list per stroke, one hit tuple per point:
        # BONE: (hit_location, tri, tri_indices)
        # geometry: (object_hit, hit_location, tri, tri_indices)
        strokes_data = []
        if use_bone:
            ## Follow Bone method (Full WIP)
            pose_bone = self._get_pose_bone(settings)
            if pose_bone is None:
                self.report({'ERROR'}, 'No Bone selected')
                return {'CANCELLED'}
            bone_plane = plane_on_bone(pose_bone,
                                       arm=settings.target_rig,
                                       set_rotation=settings.use_bone_rotation)
            for stroke in tgt_strokes:
                strokes_data.append([intersect_with_tesselated_plane(co, origin, bone_plane)
                                     for co in self._world_points(stroke, matrix)])
        else:
            ## Geometry method
            # Built once: membership is tested for every point (and search sample)
            target_objects = set(col.all_objects)
            for stroke in tgt_strokes:
                stroke_data = []
                for co in self._world_points(stroke, matrix):
                    hit = self._ray_cast_on_targets(co, origin, dg, target_objects,
                                                    settings.search_range)
                    if hit is None:
                        self.report({'ERROR'}, 'Some points did not hit any object of the target collection')
                        return {'CANCELLED'}
                    stroke_data.append(hit)
                strokes_data.append(stroke_data)

        # Copy stroke selection, jump frame and paste
        bpy.ops.gpencil.copy()
        scn.frame_set(target_frame)
        plane_co, plane_no = get_gp_draw_plane(gp)
        bpy.ops.gpencil.paste()

        if not use_bone:
            # Re-fetch the evaluated depsgraph now that the frame has changed
            dg = context.evaluated_depsgraph_get()
        matrix_inv = np.array(gp.matrix_world.inverted(), dtype='float64')
        # Pasted strokes are taken from the end of the (new) active frame's list
        new_strokes = gp.data.layers.active.active_frame.strokes[-len(strokes_data):]

        if use_bone:
            # Bone plane re-evaluated at the destination frame
            bone_plane = plane_on_bone(self._get_pose_bone(settings),
                                       arm=settings.target_rig,
                                       set_rotation=settings.use_bone_rotation)

        for new_stroke, stroke_data in zip(new_strokes, strokes_data):
            world_co_3d = []
            if use_bone:
                for hit_location, tri_a, tri_indices in stroke_data:
                    tri_b = [bone_plane[i] for i in tri_indices]
                    ## rotate tri_b by bone differential angle camera's aim axis ?
                    world_co_3d.append(barycentric_transform(hit_location, *tri_a, *tri_b))
            else:
                for object_hit, hit_location, tri_a, tri_indices in stroke_data:
                    eval_ob = object_hit.evaluated_get(dg)
                    tri_b = [eval_ob.data.vertices[i].co for i in tri_indices]
                    tri_b = matrix_transform(tri_b, eval_ob.matrix_world)
                    world_co_3d.append(barycentric_transform(hit_location, *tri_a, *tri_b))

            self._apply_world_points(new_stroke, world_co_3d, origin, plane_co, plane_no, matrix_inv)

        return {'FINISHED'}

    @staticmethod
    def _get_pose_bone(settings):
        """Return the pose bone named by settings.target_bone on settings.target_rig.

        Returns None when the rig or bone name is unset, the rig has no pose,
        or the bone does not exist.
        """
        rig = settings.target_rig
        if not rig or not settings.target_bone or rig.pose is None:
            return None
        return rig.pose.bones.get(settings.target_bone)

    @staticmethod
    def _world_points(stroke, matrix):
        """Return the stroke's point coordinates transformed by ``matrix``.

        ``matrix`` is the GP object's world matrix as a float64 numpy array.
        """
        nb_points = len(stroke.points)
        local_co = np.empty(nb_points * 3, dtype='float64')
        stroke.points.foreach_get('co', local_co)
        return matrix_transform(local_co.reshape((nb_points, 3)), matrix)

    @staticmethod
    def _ray_cast_on_targets(point_co_world, origin, dg, target_objects, search_range):
        """Find the target-collection surface behind ``point_co_world`` as seen from ``origin``.

        The point itself is ray cast first. If that does not hit an object in
        ``target_objects``, the positions yielded by ``search_square`` are tried
        in turn. On such a fallback hit, the hit location is moved to where the
        line origin -> point_co_world crosses the hit triangle's plane.

        Returns (object_hit, hit_location, tri, tri_indices), or None when no
        object of ``target_objects`` was hit.
        """
        object_hit, hit_location, tri, tri_indices = ray_cast_point(point_co_world, origin, dg)
        if object_hit and object_hit in target_objects:
            return object_hit, hit_location, tri, tri_indices

        for square_co in search_square(point_co_world, factor=search_range):
            object_hit, hit_location, tri, tri_indices = ray_cast_point(square_co, origin, dg)
            if object_hit and object_hit in target_objects:
                # Keep the neighbour's triangle but place the hit on the real point's line
                hit_location = intersect_line_plane(origin, point_co_world, tri[0], triangle_normal(*tri))
                return object_hit, hit_location, tri, tri_indices

        return None

    @staticmethod
    def _apply_world_points(stroke, world_co_3d, origin, plane_co, plane_no, matrix_inv):
        """Project world points onto the drawing plane and write them into ``stroke``.

        Each point is moved along the line from ``origin`` through it until it
        meets the plane (plane_co, plane_no). The result is then converted to
        local space with ``matrix_inv`` and written to the stroke points.
        """
        plane_world_co = [intersect_line_plane(origin, p, plane_co, plane_no) for p in world_co_3d]
        local_co = matrix_transform(plane_world_co, matrix_inv)

        nb_points = len(stroke.points)
        stroke.points.foreach_set('co', local_co.reshape(nb_points * 3))
        stroke.points.update()

# Operator classes handled by register() / unregister()
classes = (GP_OT_interpolate_stroke_simple,)

def register():
    """Register each class listed in ``classes`` with Blender, in order."""
    for cls in classes:
        bpy.utils.register_class(cls)
    
        
def unregister():
    """Unregister each class listed in ``classes``, last registered first."""
    for cls in classes[::-1]:
        bpy.utils.unregister_class(cls)