import bpy
import mathutils
from mathutils import Matrix, Vector
from math import pi
import numpy as np
from time import time
from . import utils
from mathutils.geometry import intersect_line_plane

def get_scale_matrix(scale):
    '''Return a 4x4 matrix scaling X, Y and Z by scale[0], scale[1] and scale[2].

    The result is the product of three single-axis ``Matrix.Scale`` matrices,
    so it carries no rotation or translation.
    '''
    axes = ((1, 0, 0), (0, 1, 0), (0, 0, 1))
    scale_x, scale_y, scale_z = (
        Matrix.Scale(scale[index], 4, axis) for index, axis in enumerate(axes)
    )
    return scale_x @ scale_y @ scale_z

'''
## Old reproject method using Operators:
omode = bpy.context.mode

if all_strokes:
    layers_state = [[l, l.hide, l.lock, l.lock_frame] for l in obj.data.layers]
    for l in obj.data.layers:
        l.hide = False
        l.lock = False
        l.lock_frame = False
bpy.ops.object.mode_set(mode='EDIT_GPENCIL')


for fnum in frame_list:
    bpy.context.scene.frame_current = fnum
    bpy.ops.gpencil.select_all(action='SELECT')
    bpy.ops.gpencil.reproject(type=proj_type) # 'INVOKE_DEFAULT'
    bpy.ops.gpencil.select_all(action='DESELECT')

# restore
if all_strokes:
    for layer, hide, lock, lock_frame in layers_state:
        layer.hide = hide
        layer.lock = lock
        layer.lock_frame = lock_frame

bpy.ops.object.mode_set(mode=omode)
'''

def _reproject_stroke(stroke, origin, matrix, matrix_inv, plan_co, plane_no):
    '''Project every point of one stroke on the plane, along the line origin -> point.

    :stroke: grease pencil stroke, modified in place
    :origin: world space start of the projection lines
    :matrix: object world matrix (local -> world)
    :matrix_inv: inverse of matrix (world -> local)
    :plan_co: world space point on the target plane
    :plane_no: target plane normal
    '''
    ## Plain python loop + foreach_set: the numpy batch transform was measured
    ## as slower here by the original author
    new_local_coords = []
    for p in stroke.points:
        hit = intersect_line_plane(origin, matrix @ p.co, plan_co, plane_no)
        # intersect_line_plane returns None when the line is parallel to the
        # plane: keep such a point where it is instead of crashing
        new_local_coords.extend(p.co if hit is None else matrix_inv @ hit)

    ## Set points in obj local space
    stroke.points.foreach_set('co', new_local_coords)


def batch_reproject(obj, proj_type='VIEW', all_strokes=True, restore_frame=False):
    '''Reproject the strokes of all frames of a grease pencil object on a plane.

    Pure python method (no operator call). Each point is moved along the line
    going from the scene camera location through the point, onto the plane
    returned by ``utils.get_gp_draw_plane(obj, orient=proj_type)``.

    :obj: grease pencil object to reproject
    :proj_type: orientation forwarded to ``utils.get_gp_draw_plane``
    :all_strokes: True to process every layer, False to process only selected
        layers that are neither hidden nor locked
    :restore_frame: set the scene back to its initial frame when done
    :raises RuntimeError: when the scene has no active camera
    '''

    scn = bpy.context.scene
    if scn.camera is None:
        # The camera location is the origin of every projection line
        raise RuntimeError('Reprojection needs an active camera in the scene')

    oframe = scn.frame_current

    # NOTE(review): the plane is evaluated once, at the frame the scene is on
    # when entering the function. If get_gp_draw_plane depends on the object
    # transform and the object is animated, it may need to be evaluated after
    # each frame_set instead - confirm against utils.get_gp_draw_plane
    plan_co, plane_no = utils.get_gp_draw_plane(obj, orient=proj_type)

    # Unique, ordered frame numbers that hold at least one stroke
    frame_list = sorted({f.frame_number for l in obj.data.layers for f in l.frames if len(f.strokes)})

    try:
        for fnum in frame_list:
            scn.frame_set(fnum) # refresh scene (setting frame_current alone does not)

            origin = scn.camera.matrix_world.to_translation()
            # Read the object matrix once per frame instead of once per point
            matrix = obj.matrix_world.copy()
            matrix_inv = matrix.inverted()

            for l in obj.data.layers:
                if not all_strokes and (not l.select or l.hide or l.lock):
                    continue
                f = next((f for f in l.frames if f.frame_number == fnum), None)
                if f is None:
                    continue
                for s in f.strokes:
                    _reproject_stroke(s, origin, matrix, matrix_inv, plan_co, plane_no)
    finally:
        # Also executed on error so the user is not left on an arbitrary frame
        if restore_frame:
            scn.frame_current = oframe

    ## Update the layer and redraw all viewports
    obj.data.layers.update()
    utils.refresh_areas()

def align_global(reproject=True, ref=None, all_strokes=True):
    '''Rotate the active GP object to face ``ref`` while keeping strokes in place.

    The object keeps its location and scale; its rotation becomes the rotation
    of ``ref`` followed by a -90 degrees rotation on X. Stroke points of every
    layer and frame are counter-transformed so they do not move in world space.

    :reproject: when True, call batch_reproject with proj_type='FRONT' afterwards
    :ref: object giving the orientation, defaults to the scene camera
    :all_strokes: forwarded to batch_reproject
    :raises RuntimeError: when no ``ref`` is given and the scene has no camera
    '''

    # The caller supplied ref is honoured; the scene camera is only a fallback
    if not ref:
        ref = bpy.context.scene.camera
    if not ref:
        raise RuntimeError('Realign needs a reference object or an active scene camera')

    o = bpy.context.object

    _, ref_rot, _ = ref.matrix_world.decompose()

    # Copy so the value stays the "before" matrix once the object is modified
    if o.parent:
        mat = o.matrix_world.copy()
    else:
        mat = o.matrix_basis.copy()

    o_loc, _, o_scale = mat.decompose()

    mat_90 = Matrix.Rotation(-pi/2, 4, 'X')

    loc_mat = Matrix.Translation(o_loc)
    rot_mat = ref_rot.to_matrix().to_4x4() @ mat_90
    scale_mat = get_scale_matrix(o_scale)

    new_mat = loc_mat @ rot_mat @ scale_mat

    # old local space -> world -> new local space, computed once for all points.
    # Column-vector form (world_co = mat @ co, then new_mat.inverted() @ world_co):
    # it gives the same result as the row-vector form for rotations and uniform
    # scale, and stays correct when the object scale is not uniform.
    transfer = new_mat.inverted() @ mat

    for l in o.data.layers:
        for f in l.frames:
            for s in f.strokes:
                s.points.foreach_set('co', [axis for p in s.points for axis in transfer @ p.co])
                ## force update (adding/deleting a point is "safer" than points.update())
                s.points.add(1)
                s.points.pop()

    if o.parent:
        o.matrix_world = new_mat
    else:
        o.matrix_basis = new_mat

    if reproject:
        batch_reproject(o, proj_type='FRONT', all_strokes=all_strokes)


def _facing_matrices(o, ref, mat_90):
    '''Return (current matrix, realigned matrix) of object o at the current frame.

    The current matrix is a copy of matrix_world when the object is parented,
    of matrix_basis otherwise. The realigned matrix keeps its location and
    scale and takes the rotation of ``ref`` combined with ``mat_90``.
    '''
    mat = (o.matrix_world if o.parent else o.matrix_basis).copy()
    _, ref_rot, _ = ref.matrix_world.decompose()
    o_loc, _, o_scale = mat.decompose()

    loc_mat = Matrix.Translation(o_loc)
    rot_mat = ref_rot.to_matrix().to_4x4() @ mat_90
    scale_mat = get_scale_matrix(o_scale)
    return mat, loc_mat @ rot_mat @ scale_mat


def align_all_frames(reproject=True, ref=None, all_strokes=True):
    '''Realign an object with animated rotation to face ``ref`` on every key.

    First pass: for each GP frame holding strokes, points are counter-transformed
    so they stay in place once the object faces ``ref`` at that frame.
    Second pass: every existing rotation key of the active rotation channel is
    overwritten with the facing rotation evaluated at the key frame.
    When the active rotation channel has no key, the work is delegated to
    align_global.

    :reproject: when True, call batch_reproject with proj_type='FRONT' afterwards
    :ref: object giving the orientation, defaults to the scene camera
    :all_strokes: forwarded to batch_reproject
    :raises RuntimeError: when no ``ref`` is given and the scene has no camera
    '''

    print('aligning all frames...')

    scn = bpy.context.scene
    o = bpy.context.object
    if not ref:
        ref = scn.camera
    if not ref:
        raise RuntimeError('Realign needs a reference object or an active scene camera')

    # Rotation channel actually used by the object
    channel = 'rotation_quaternion' if o.rotation_mode == 'QUATERNION' else 'rotation_euler'

    anim = o.animation_data
    action = anim.action if anim else None
    fcurves = action.fcurves if action else ()

    # Unique frames holding a rotation key, visited in chronological order
    rot_keys = sorted({int(k.co.x) for fcu in fcurves if fcu.data_path == channel for k in fcu.keyframe_points})

    if not rot_keys:
        # Without a rotation key the second pass would never rotate the object
        # while the first pass moves the points: do a static realign instead
        align_global(reproject=reproject, ref=ref, all_strokes=all_strokes)
        return

    # TODO # TOTHINK
    # for now the rotation of the object is adjusted at every frame....
    # might be better to check camera rotation of the current frame only (stored as copy).
    # else the object rotate following the cameraview vector (not constant)...

    mat_90 = Matrix.Rotation(-pi/2, 4, 'X')

    for l in o.data.layers:
        for f in l.frames:
            if not len(f.strokes):
                # nothing to move, skip the costly scene evaluation
                continue

            # set the frame to dedicated
            scn.frame_set(f.frame_number)

            mat, new_mat = _facing_matrices(o, ref, mat_90)
            # old local space -> world -> new local space (column-vector form,
            # also valid with non uniform object scale), once per frame
            transfer = new_mat.inverted() @ mat

            for s in f.strokes:
                s.points.foreach_set('co', [axis for p in s.points for axis in transfer @ p.co])
                ## force update
                s.points.add(1)
                s.points.pop()

    for fnum in rot_keys:
        scn.frame_set(fnum)
        _, new_mat = _facing_matrices(o, ref, mat_90)

        if o.parent:
            o.matrix_world = new_mat
        else:
            o.matrix_basis = new_mat

        o.keyframe_insert(channel, index=-1, frame=scn.frame_current, options={'INSERTKEY_AVAILABLE'})

    if reproject:
        batch_reproject(o, proj_type='FRONT', all_strokes=all_strokes)

    return


class GPTB_OT_realign(bpy.types.Operator):
    # Realign the active grease pencil object with the scene camera.
    # invoke() opens an options dialog, execute() dispatches to
    # align_all_frames() when the object action has rotation fcurves and to
    # align_global() otherwise.
    bl_idname = "gp.realign"
    bl_label = "Realign GP"
    bl_description = "Realign the grease pencil front axis with active camera"
    bl_options = {"REGISTER"}

    @classmethod
    def poll(cls, context):
        return context.object and context.object.type == 'GPENCIL'

    reproject : bpy.props.BoolProperty(
        name='Reproject', default=True,
        description='Reproject stroke on the new alignment')
    
    all_strokes : bpy.props.BoolProperty(
        name='All Strokes', default=True,
        description='Hided and locked layer will also be reprojected')

    set_draw_axis : bpy.props.BoolProperty(
        name='Set draw axis to Front', default=True,
        description='Set the gpencil draw plane axis to Front')
    ## add option to bake strokes if rotation anim is not constant ? might generate too many Keyframes

    def invoke(self, context, event):
        '''Check preconditions, detect animated rotation, then open the dialog.'''
        if context.object.data.use_multiedit:
            self.report({'ERROR'}, 'Does not work in Multiframe mode')
            return {"CANCELLED"}

        if context.scene.camera is None:
            # alignment and reprojection both read the scene camera
            self.report({'ERROR'}, 'No active camera in the scene')
            return {"CANCELLED"}

        # Warning shown by draw(), empty when the rotation is not animated
        self.alert = ''
        o = context.object
        if o.animation_data and o.animation_data.action:
            act = o.animation_data.action
            for chan in ('rotation_euler', 'rotation_quaternion'):
                if not act.fcurves.find(chan):
                    continue
                self.alert = 'Animated Rotation (CONSTANT interpolation)'
                interpos = [p for fcu in act.fcurves if fcu.data_path == chan for p in fcu.keyframe_points if p.interpolation != 'CONSTANT']
                if interpos:
                    self.alert = f'Animated Rotation ! ({len(interpos)} key not constant)'
                break

        return context.window_manager.invoke_props_dialog(self, width=450)

    def draw(self, context):
        '''Draw the dialog: warnings, reprojection options and draw axis hint.'''
        layout = self.layout
        layout.label(text='Realign the GP object : Front axis facing active camera')
        if self.alert:
            layout.label(text=self.alert, icon='ERROR')
            layout.label(text='(rotations key will be overwritten to face camera)')
        
        # layout.separator()
        box = layout.box()
        box.prop(self, "reproject")
        if self.reproject:
            box.label(text='After Realigning, reproject each frames on front axis')
            # region_data is None when the dialog is not drawn from a 3D viewport
            rv3d = context.region_data
            if rv3d is None or rv3d.view_perspective != 'CAMERA':
                # batch_reproject always projects from the scene camera location
                box.label(text='Not in camera ! (reprojection is made from camera)', icon='ERROR')
            box.prop(self, "all_strokes")
            if not self.all_strokes:
                box.label(text='Only visible and unlocked layers will be reprojected', icon='INFO')

        axis = context.scene.tool_settings.gpencil_sculpt.lock_axis
        if axis != 'AXIS_Y':
            # label and icon of every draw plane except Front (AXIS_Y)
            orient = {
                'VIEW'  : ['View', 'RESTRICT_VIEW_ON'],
                # 'AXIS_Y': ['front (X-Z)', 'AXIS_FRONT'], # 
                'AXIS_X': ['side (Y-Z)', 'AXIS_SIDE'],
                'AXIS_Z': ['top (X-Y)', 'AXIS_TOP'],
                'CURSOR': ['Cursor', 'PIVOT_CURSOR'],
                }
            
            box = layout.box()
            box.label(text=f'Current drawing plane : {orient[axis][0]}', icon=orient[axis][1])
            box.prop(self, "set_draw_axis")

        
    def exit(self, context, frame):
        '''Restore the scene frame and optionally set the draw plane to Front.'''
        context.scene.frame_current = frame
        if context.scene.tool_settings.gpencil_sculpt.lock_axis != 'AXIS_Y' and self.set_draw_axis:
            context.scene.tool_settings.gpencil_sculpt.lock_axis = 'AXIS_Y'

    def execute(self, context):
        '''Run the animated or the global realign, then restore the frame.'''
        if context.scene.camera is None:
            # execute can be called without going through invoke
            self.report({'ERROR'}, 'No active camera in the scene')
            return {"CANCELLED"}

        t0 = time()
        oframe = context.scene.frame_current

        o = context.object
        action = o.animation_data.action if o.animation_data else None
        animated = bool(action) and any(
            action.fcurves.find(chan) for chan in ('rotation_euler', 'rotation_quaternion'))

        # all_strokes is forwarded so the dialog option is actually applied
        if animated:
            align_all_frames(reproject=self.reproject, all_strokes=self.all_strokes)
            print(f'\nAnim realign ({time()-t0:.2f}s)')
        else:
            align_global(reproject=self.reproject, all_strokes=self.all_strokes)
            print(f'\nGlobal Realign ({time()-t0:.2f}s)')

        self.exit(context, oframe)
        return {"FINISHED"}


class GPTB_OT_batch_reproject_all_frames(bpy.types.Operator):
    # Reproject the strokes of every frame of the active grease pencil object
    # on the chosen plane, through batch_reproject().
    bl_idname = "gp.batch_reproject_all_frames"
    bl_label = "Reproject All Frames"
    bl_description = "Reproject all frames of active object."
    bl_options = {"REGISTER"}

    @classmethod
    def poll(cls, context):
        return context.object and context.object.type == 'GPENCIL'
    
    all_strokes : bpy.props.BoolProperty(
        name='All Strokes', default=True,
        description='Hided and locked layer will also be reprojected')
 
    type : bpy.props.EnumProperty(name='Type',
        items=(('CURRENT', "Current", ""),
               ('FRONT', "Front", ""),
               ('SIDE', "Side", ""),
               ('TOP', "Top", ""),
               ('VIEW', "View", ""),
               ('CURSOR', "Cursor", ""),
               # ('SURFACE', "Surface", ""),
               ),
        default='CURRENT')

    def invoke(self, context, event):
        '''Check preconditions then open the options dialog.'''
        if context.object.data.use_multiedit:
            self.report({'ERROR'}, 'Does not work in Multi-edit')
            return {"CANCELLED"}
        if context.scene.camera is None:
            # batch_reproject projects from the scene camera location
            self.report({'ERROR'}, 'No active camera in the scene')
            return {"CANCELLED"}
        return context.window_manager.invoke_props_dialog(self)

    def draw(self, context):
        '''Draw the dialog: camera hint, options and current draw plane.'''
        layout = self.layout
        # region_data is None when the dialog is not drawn from a 3D viewport
        rv3d = context.region_data
        if rv3d is None or rv3d.view_perspective != 'CAMERA':
            # layout.label(text='Not in camera ! (reprojection is made from view)', icon='ERROR')
            layout.label(text='Reprojection is made from camera', icon='ERROR')
        layout.prop(self, "all_strokes")
        layout.prop(self, "type", text='Project Axis')

        ## Hint show axis
        if self.type == 'CURRENT':
            ## Show as prop
            # row = layout.row()
            # row.prop(context.scene.tool_settings.gpencil_sculpt, 'lock_axis', text='Current', icon='INFO')
            # row.enabled = False
            
            # label and icon for each value of the draw plane setting
            orient = {
                'VIEW'  : ['View', 'RESTRICT_VIEW_ON'],
                'AXIS_Y': ['front (X-Z)', 'AXIS_FRONT'], # AXIS_Y
                'AXIS_X': ['side (Y-Z)', 'AXIS_SIDE'], # AXIS_X
                'AXIS_Z': ['top (X-Y)', 'AXIS_TOP'], # AXIS_Z
                'CURSOR': ['Cursor', 'PIVOT_CURSOR'],
                }
            box = layout.box()
            axis = context.scene.tool_settings.gpencil_sculpt.lock_axis
            box.label(text=orient[axis][0], icon=orient[axis][1])

    def execute(self, context):
        '''Reproject all frames and report the elapsed time.'''
        if context.scene.camera is None:
            # execute can be called without going through invoke
            self.report({'ERROR'}, 'No active camera in the scene')
            return {"CANCELLED"}

        t0 = time()
        # 'CURRENT' is passed on as None
        # NOTE(review): presumably utils.get_gp_draw_plane then uses the
        # current draw plane setting - confirm in utils
        orient = None if self.type == 'CURRENT' else self.type

        batch_reproject(context.object, proj_type=orient, all_strokes=self.all_strokes, restore_frame=True)

        self.report({'INFO'}, f'Reprojected in ({time()-t0:.2f}s)' )

        return {"FINISHED"}

### -- MENU ENTRY --

def reproject_clean_menu(self, context):
    '''Menu draw callback: add the batch reproject entry in GP edit mode only.'''
    if context.mode != 'EDIT_GPENCIL':
        return
    layout = self.layout
    # invoke context needed for the operator popup (also works with 'INVOKE_DEFAULT')
    layout.operator_context = 'INVOKE_REGION_WIN'
    layout.operator('gp.batch_reproject_all_frames', icon='KEYTYPE_JITTER_VEC')

def reproject_context_menu(self, context):
    '''Menu draw callback: add the batch reproject entry in GP edit mode
    when the edit select mode is 'STROKE'.
    '''
    if context.mode != 'EDIT_GPENCIL':
        return
    select_mode = context.scene.tool_settings.gpencil_selectmode_edit
    if select_mode != 'STROKE':
        return
    layout = self.layout
    # invoke context needed for the operator popup
    layout.operator_context = 'INVOKE_REGION_WIN'
    layout.operator('gp.batch_reproject_all_frames', icon='KEYTYPE_JITTER_VEC')

# Operator classes of this module, registered in this order by register()
# and unregistered in reverse order by unregister().
classes = (
    GPTB_OT_realign,
    GPTB_OT_batch_reproject_all_frames,
)

def register():
    '''Register the operator classes, then append the menu entries.'''
    register_class = bpy.utils.register_class
    for operator_cls in classes:
        register_class(operator_cls)

    menus = bpy.types
    menus.VIEW3D_MT_gpencil_edit_context_menu.append(reproject_context_menu)
    menus.GPENCIL_MT_cleanup.append(reproject_clean_menu)

def unregister():
    '''Remove the menu entries, then unregister the classes in reverse order.'''
    menus = bpy.types
    menus.GPENCIL_MT_cleanup.remove(reproject_clean_menu)
    menus.VIEW3D_MT_gpencil_edit_context_menu.remove(reproject_context_menu)

    for operator_cls in classes[::-1]:
        bpy.utils.unregister_class(operator_cls)