import bpy
from random import randint
from mathutils import Vector
from math import radians
from random import random as rand
import numpy as np
from bpy_extras.object_utils import world_to_camera_view as cam_space
import bmesh
from .utils import get_gp_draw_plane, link_vert,gp_stroke_to_bmesh,draw_gp_stroke,remapping


def to_bl_image(array, img):
    """Write a 2D grid of colours into the Blender image named ``img``.

    ``array`` is indexed ``array[y][x]``: its number of rows is the image
    height and the length of its first row is the width (only the first
    ``width`` cells of each row are read). A cell is either a colour — any
    list, tuple or NumPy array whose first three items are R, G, B — or a
    scalar, which is written as the grey ``(v, v, v)``. Alpha is always 1.
    Rows are written in order, so ``array[0]`` fills the first row of
    ``image.pixels``.

    The image is created with ``bpy.data.images.new`` when it does not exist.
    Its ``generated_width``/``generated_height`` are then set to the array
    size before all pixels are replaced.
    """
    height = len(array)
    width = len(array[0])

    image = bpy.data.images.get(img)
    if not image:
        image = bpy.data.images.new(img, width, height)

    image.generated_width = width
    image.generated_height = height

    # Any sequence-like cell is a colour; anything else is a grey level.
    # (Checking only ``list`` would turn tuple/ndarray cells into
    # ``[cell] * 3`` and break the pixel assignment.)
    colour_types = (list, tuple, np.ndarray)

    output_pixels = []
    for y in range(height):
        row = array[y]
        for x in range(width):
            col = row[x]
            if not isinstance(col, colour_types):
                col = (col, col, col)
            output_pixels.extend((col[0], col[1], col[2], 1))

    image.pixels = output_pixels


def bm_angle_split(bm, angle) :
    """Cut a bmesh poly-line at its sharp corners.

    Starting from ``bm.verts[0]``, ``link_vert`` collects the connected
    vertices. Each collected vertex that has exactly two linked edges and an
    edge angle greater than ``angle`` (degrees) is separated with
    ``bmesh.utils.vert_separate``. The first vertex returned by every
    separation (plus the second one for the very first separation) becomes
    the start of a new ``link_vert`` walk.

    Returns:
        A list of vertex sequences produced by ``link_vert``: one per start
        vertex collected above, or only the initial walk when nothing was
        split.
    """
    bm.verts.ensure_lookup_table()
    origin = bm.verts[0]
    initial_loop = link_vert(origin, [origin])

    # Decide every split point up front, before the topology changes.
    sharp_verts = [
        vert for vert in initial_loop
        if len(vert.link_edges) == 2 and vert.calc_edge_angle() > radians(angle)
    ]

    start_verts = []
    for position, vert in enumerate(sharp_verts):
        separated = bmesh.utils.vert_separate(vert, vert.link_edges)
        start_verts.append(separated[0])
        if position == 0:
            start_verts.append(separated[1])

    bm.verts.ensure_lookup_table()

    if not start_verts:
        return [initial_loop]
    return [link_vert(start, [start]) for start in start_verts]

def bm_uniform_density(bm, cam, max_spacing):
    """Subdivide long edges of ``bm`` based on their on-screen length.

    Both end points of every edge are projected with ``world_to_camera_view``
    (module-level ``cam_space``) for the active scene and ``cam``. The
    normalised y coordinate is scaled by ``resolution_y / resolution_x`` so
    that x and y share one unit (a fraction of the render width). An edge
    whose projected length exceeds ``max_spacing`` is subdivided with
    ``round(length / max_spacing) - 1`` cuts.

    NOTE(review): ``cam_space`` is fed the raw ``vert.co`` values, which
    assumes the bmesh vertices are already in world space — confirm with the
    code that builds ``bm``.

    Args:
        bm: bmesh to modify in place.
        cam: camera object used for the projection.
        max_spacing: maximum projected edge length, in units of render width.

    Returns:
        The same ``bm`` (modified in place).

    Raises:
        ValueError: if ``max_spacing`` is not positive (it would otherwise
            divide by zero or request a negative number of cuts).
    """
    if max_spacing <= 0:
        raise ValueError("max_spacing must be positive")

    scene = bpy.context.scene
    ratio = scene.render.resolution_y / scene.render.resolution_x

    # Iterate a snapshot: subdividing adds new edges to ``bm.edges``.
    for edge in bm.edges[:]:
        first = Vector(cam_space(scene, cam, edge.verts[0].co)[:-1])
        last = Vector(cam_space(scene, cam, edge.verts[1].co)[:-1])

        first[1] *= ratio
        last[1] *= ratio

        length = (last - first).length
        if length > max_spacing:
            bmesh.ops.subdivide_edges(
                bm, edges=[edge], cuts=round(length / max_spacing) - 1)

    return bm


def gp_stroke_angle_split (frame, strokes, angle):
    """Split grease-pencil strokes at their sharp corners.

    Each stroke in ``strokes`` is converted to a bmesh with
    ``gp_stroke_to_bmesh``, cut by ``bm_angle_split`` wherever a corner is
    sharper than ``angle`` degrees, removed from ``frame``, and redrawn as one
    new stroke per resulting piece. Each vertex's strength, pressure and
    select values carry over, as do the stroke's colour and line width.

    Returns:
        The list of newly drawn strokes.
    """
    created = []
    for info in gp_stroke_to_bmesh(strokes):
        bm = info['bmesh']
        color = info['color']
        width = info['line_width']

        float_layers = bm.verts.layers.float
        strength_layer = float_layers['strength']
        pressure_layer = float_layers['pressure']
        select_layer = bm.verts.layers.int['select']

        pieces = bm_angle_split(bm, angle)

        frame.strokes.remove(info['stroke'])

        for piece in pieces:
            points = [
                {
                    'co': vert.co,
                    'strength': vert[strength_layer],
                    'pressure': vert[pressure_layer],
                    'select': vert[select_layer],
                }
                for vert in piece
            ]
            created.append(draw_gp_stroke(points, frame, color, width=width))

    return created


def gp_stroke_uniform_density(cam, frame, strokes, max_spacing):
    """Resample grease-pencil strokes to a more uniform on-screen spacing.

    Each stroke in ``strokes`` is converted to a bmesh with
    ``gp_stroke_to_bmesh``. A copy of that bmesh is subdivided by
    ``bm_uniform_density`` using ``cam`` and ``max_spacing``. The source
    stroke is then removed from ``frame``, and one replacement stroke is drawn
    from the ``link_vert`` walk that starts at the copy's first vertex. Each
    vertex's strength, pressure and select values carry over, as do the
    stroke's colour and line width.

    Returns:
        The list of replacement strokes.
    """
    replacements = []

    for info in gp_stroke_to_bmesh(strokes):
        work_bm = info['bmesh'].copy()
        color = info['color']
        width = info['line_width']

        float_layers = work_bm.verts.layers.float
        strength_layer = float_layers['strength']
        pressure_layer = float_layers['pressure']
        select_layer = work_bm.verts.layers.int['select']

        bm_uniform_density(work_bm, cam, max_spacing)

        frame.strokes.remove(info['stroke'])
        work_bm.verts.ensure_lookup_table()

        first_vert = work_bm.verts[0]
        walk = link_vert(first_vert, [first_vert])
        points = [
            {
                'co': vert.co,
                'strength': vert[strength_layer],
                'pressure': vert[pressure_layer],
                'select': vert[select_layer],
            }
            for vert in walk
        ]

        replacements.append(draw_gp_stroke(points, frame, color, width=width))

    return replacements


def along_stroke(stroke, attr, length, min, max) :
    """Taper a numeric point attribute towards both ends of a stroke.

    For each point, ``distance`` is its index distance to the nearest end of
    the stroke (0 for the first and last points). When ``distance < length``,
    ``attr`` is multiplied by ``remapping(distance / length, 0, 1, min, max)``.
    Points at least ``length`` away from both ends are left unchanged. The
    two ends are now symmetric: the last point receives the same ratio (0)
    as the first. On strokes shorter than ``2 * length``, the nearer end
    decides the factor instead of the end taper silently overwriting the
    start taper. A non-positive ``length`` leaves the stroke untouched.

    Note: ``min`` and ``max`` shadow the builtins inside this function; the
    names are kept for compatibility with keyword callers.
    """
    if length <= 0:
        return

    point_count = len(stroke.points)
    for index, point in enumerate(stroke.points):
        from_end = point_count - 1 - index
        # Builtin min() is shadowed by the parameter, so compare explicitly.
        distance = index if index <= from_end else from_end
        if distance >= length:
            continue

        factor = remapping(distance / length, 0, 1, min, max)
        setattr(point, attr, getattr(point, attr) * factor)

def randomise_points(mat, points, attr, strength) :
    """Jitter one attribute of every point by a uniform random amount.

    Each random offset is ``rand() - 0.5`` (in ``[-0.5, 0.5)``) multiplied
    by ``strength``.

    * ``attr == 'co'``: two offsets are drawn per point, the X one first,
      then the Y one. Each offset vector is multiplied by ``mat``, and the
      translation of ``mat`` is subtracted again. For an affine ``mat``, the
      point therefore moves along ``mat``'s transformed X/Y axes without
      picking up its translation.
    * any other ``attr``: one offset is drawn per point and added to the
      current value read with ``getattr``.
    """
    if attr != 'co':
        for point in points:
            current = getattr(point, attr)
            setattr(point, attr, current + (rand() - 0.5) * strength)
        return

    for point in points:
        jitter_x = rand() - 0.5
        jitter_y = rand() - 0.5

        along_x = Vector((jitter_x * strength, 0.0, 0.0))
        along_y = Vector((0.0, jitter_y * strength, 0.0))

        point.co += mat * along_x - mat.to_translation()
        point.co += mat * along_y - mat.to_translation()


def zoom_to_object(cam, resolution, box, margin=0.01) :
    """Build a zoomed copy of ``cam`` framing a 2D box and return its matrices.

    Args:
        cam: camera object to copy; the original object and its data are not
            modified.
        resolution: ``(width, height)`` of the base render.
        box: ``(min_x, max_x, min_y, max_y)`` in normalised camera-view
            coordinates (presumably as produced by ``world_to_camera_view`` —
            TODO confirm with callers).
        margin: added to the larger side of ``box`` to get the zoom factor.

    Returns:
        ``(modelview_matrix, projection_matrix, frame, resolution)``:
        ``frame`` holds the zoomed camera's ``view_frame`` corners multiplied
        by its world matrix, and ``resolution`` is the input resolution
        scaled by the zoom factor (truncated to int).

    Side effects:
        The copied camera (with its own copy of the camera data) is linked
        into ``bpy.context.scene`` and left there. NOTE(review): it is never
        removed, so every call adds an object to the scene.
        The scene render resolution is temporarily overridden and then
        restored, including when an error is raised.
    """
    min_x = box[0]
    max_x = box[1]
    min_y = box[2]
    max_y = box[3]

    ratio = resolution[0] / resolution[1]
    scene = bpy.context.scene

    zoom_cam = cam.copy()
    zoom_cam.data = zoom_cam.data.copy()

    center = ((max_x + min_x) / 2, (max_y + min_y) / 2)
    factor = max((max_x - min_x), (max_y - min_y)) + margin

    # Offset the lens shift by the box centre's distance from the view
    # centre (0.5, 0.5), expressed relative to the zoomed frame.
    # NOTE(review): dividing shift_y by ``ratio`` assumes shift is measured
    # in units of the render width (landscape output) — confirm for
    # portrait resolutions.
    zoom_cam.data.shift_x += (center[0] - 0.5) / factor
    zoom_cam.data.shift_y += (center[1] - 0.5) / factor / ratio

    zoom_cam.data.lens /= factor

    scene.objects.link(zoom_cam)

    resolution = (int(resolution[0] * factor), int(resolution[1] * factor))

    render = scene.render
    saved_res_x = render.resolution_x
    saved_res_y = render.resolution_y

    # view_frame(scene) takes the scene for its aspect calculation, so the
    # zoomed resolution is applied while the frame/matrices are computed.
    render.resolution_x = resolution[0]
    render.resolution_y = resolution[1]
    try:
        frame = zoom_cam.data.view_frame(scene)
        frame = [zoom_cam.matrix_world * corner for corner in frame]

        modelview_matrix = zoom_cam.matrix_world.inverted().copy()
        projection_matrix = zoom_cam.calc_matrix_camera(
            resolution[0], resolution[1], 1, 1).copy()
    finally:
        render.resolution_x = saved_res_x
        render.resolution_y = saved_res_y

    return modelview_matrix, projection_matrix, frame, resolution

def _namespaced_colo_name(namespace, name):
    """Return ``<namespace>_COLO_<rest>`` for ``name``.

    When ``name`` starts with ``<namespace>_``, ``rest`` is everything after
    the first underscore of ``name``; otherwise it is ``name`` itself.
    NOTE(review): splitting at the *first* underscore does not strip the
    whole prefix when ``namespace`` itself contains underscores — confirm
    whether other code relies on this exact naming before changing it.
    """
    if name.startswith(namespace + '_'):
        return namespace + '_' + 'COLO_' + name.split('_', 1)[1]
    return namespace + '_' + 'COLO_' + name


def _unique_random_color(taken):
    """Draw a random 8-bit RGB tuple that is not already a key of ``taken``."""
    color = (randint(0, 255), randint(0, 255), randint(0, 255))
    while color in taken:
        color = (randint(0, 255), randint(0, 255), randint(0, 255))
    return color


def get_object_info(mesh_groups, order_list=None) :
    """Collect per-object drawing data for the meshes in ``mesh_groups``.

    Args:
        mesh_groups: iterable of dicts, each with an ``"objects"`` list, a
            ``"namespace"`` string and an optional ``"dupli_object"`` whose
            world matrix is multiplied in front of each object's.
        order_list: optional list of ``name`` values giving the desired output
            order. When it is non-empty, names missing from it are appended
            to it **in place** (in depth order), and the result is sorted by
            each name's first position in it.

    Returns:
        ``(mesh_info, convert_table)``.

        ``mesh_info`` is a list of per-object dicts (object, bmesh, world
        matrix, ``box_2d`` in camera-view coordinates, depth, materials,
        colour indexes, ...). By default it is sorted by decreasing distance
        from the camera to the world-space bounding-box centre.

        ``convert_table`` maps each random 8-bit RGB tuple to its colour
        index, plus white -> -1 and black -> 0. Random colours are guaranteed
        unique, so no entry can overwrite another.

    Objects whose bmesh (from ``bm.from_object``) has no vertices are skipped.
    The scene render resolution is set to 1024x1024 while bounding boxes are
    projected (presumably so ``box_2d`` does not depend on the output aspect
    ratio) and is restored afterwards, even if an error is raised.
    """
    scene = bpy.context.scene
    cam = scene.camera
    render = scene.render
    res_x = int(render.resolution_x)
    res_y = int(render.resolution_y)

    render.resolution_x = 1024
    render.resolution_y = 1024

    convert_table = {(255, 255, 255): -1, (0, 0, 0): 0}
    mesh_info = []
    color_index = 1

    try:
        cam_coord = cam.matrix_world.to_translation()

        for group_index, mesh_group in enumerate(mesh_groups):
            for ob in mesh_group["objects"]:
                ob_info = {"object": ob, "materials": [], "group_index": group_index,
                           'color_indexes': []}

                namespace = mesh_group['namespace']
                ob_info['namespace'] = namespace
                ob_info['name'] = _namespaced_colo_name(namespace, ob.name)

                bm = bmesh.new()
                bm.from_object(ob, scene)
                ob_info["bm"] = bm

                if not bm.verts:
                    # Nothing to draw: release the bmesh now instead of
                    # waiting for it to be garbage-collected.
                    bm.free()
                    continue

                ob_info["matrix"] = ob.matrix_world
                dupli_object = mesh_group.get("dupli_object")
                if dupli_object:
                    ob_info["matrix"] = dupli_object.matrix_world * ob.matrix_world

                global_bbox = [ob_info["matrix"] * Vector(v) for v in ob.bound_box]
                global_bbox_center = Vector(np.mean(global_bbox, axis=0))

                # Project the 8 bbox corners; keep only the 2D (x, y) part.
                bbox_cam_space = [cam_space(scene, cam, p)[:-1] for p in global_bbox]
                xs = [p[0] for p in bbox_cam_space]
                ys = [p[1] for p in bbox_cam_space]
                ob_info['box_2d'] = [min(xs), max(xs), min(ys), max(ys)]

                ob_info["depth"] = Vector(global_bbox_center - cam_coord).length

                for slot in ob.material_slots:
                    mat = slot.material
                    mat_info = {'index': color_index}
                    if mat:
                        # Apply a 1/2.2 gamma to the material's diffuse colour.
                        color = [pow(v, 1 / 2.2) for v in mat.diffuse_color]
                        name = mat.name
                    else:
                        color = [1, 0, 1]
                        name = "default"

                    random_color = _unique_random_color(convert_table)

                    mat_info["name"] = _namespaced_colo_name(namespace, name)
                    mat_info["color"] = color
                    mat_info["random_color"] = random_color

                    ob_info["materials"].append(mat_info)
                    ob_info["color_indexes"].append(color_index)

                    convert_table[random_color] = color_index
                    color_index += 1

                if not ob.material_slots:
                    random_color = _unique_random_color(convert_table)
                    ob_info["random_color"] = random_color
                    ob_info["color"] = (0.5, 0.5, 0.5)
                    ob_info["color_indexes"].append(color_index)
                    convert_table[random_color] = color_index
                    color_index += 1

                mesh_info.append(ob_info)
    finally:
        render.resolution_x = res_x
        render.resolution_y = res_y

    # Farthest objects first.
    mesh_info = sorted(mesh_info, key=lambda info: info['depth'], reverse=True)

    if order_list:
        for info in mesh_info:
            if info['name'] not in order_list:
                order_list.append(info['name'])

        # First position of each name, matching list.index() semantics
        # without its O(n) cost per lookup.
        position = {}
        for pos, name in enumerate(order_list):
            position.setdefault(name, pos)

        mesh_info = sorted(mesh_info, key=lambda info: position[info['name']])

    return mesh_info, convert_table

def redraw_ui() -> None:
    """Tag every area of every screen in ``bpy.data.screens`` for redraw."""
    all_areas = (area for screen in bpy.data.screens for area in screen.areas)
    for area in all_areas:
        area.tag_redraw()