# This is a CamiTK Python action.
#
# It randomly displaces the points of a MeshComponent at regular intervals
# (this is the equivalent of the Shaker tutorial).

import camitk
import numpy as np
from PySide2.QtCore import QTimer

def tick(self:camitk.Action):
    """Move every point of ``self.mesh`` by one random step (QTimer callback).

    Each point is displaced along its own random unit direction by a distance
    equal to the "Intensity" parameter, read as a percentage of the diagonal
    of the mesh's current bounding box. Steps accumulate from tick to tick
    because each one starts from the current point positions.

    When "Show Displacement" is checked, a "Displacement" point data holding
    each point's distance to its position recorded in ``self.initial_position``
    is added to the mesh; otherwise the data representation is switched off.
    Finally the timer interval is refreshed from "Update Intervals" so that
    edits to that parameter take effect while shaking.

    Raises:
        ValueError: if the point set is not an (N, 3) array, or if
            "Intensity" is outside [0, 100]. These are explicit checks rather
            than ``assert`` so they still run under ``python -O``.
    """
    # get the point positions as a numpy array
    points = self.mesh.getPointSetAsNumpy()
    percent = self.getParameterValue("Intensity")

    # check dimensions and parameter range
    if points.ndim != 2 or points.shape[1] != 3:
        raise ValueError("Input must be an (N, 3) array")
    if not 0 <= percent <= 100:
        raise ValueError("Percent must be between 0 and 100")

    moved_points = points
    # An empty point set has no bounding box (min/max would raise): nothing to move.
    if len(points) > 0:
        # Step length = percentage of the bounding-box diagonal.
        # NOTE(review): the box is measured on the current, already displaced
        # points, so the step grows as the mesh spreads out. Measure it on the
        # initial positions instead if a constant step is intended.
        diag = np.linalg.norm(points.max(axis=0) - points.min(axis=0))
        move_distance = (percent / 100.0) * diag

        # One random unit direction per point (normalised Gaussian samples are
        # uniformly distributed on the unit sphere).
        random_vectors = np.random.randn(*points.shape)
        norms = np.linalg.norm(random_vectors, axis=1, keepdims=True)
        # Guard against an all-zero sample: such a point just stays put this tick.
        norms[norms == 0.0] = 1.0
        directions = random_vectors / norms

        # scale, move and update the mesh
        moved_points = points + directions * move_distance
        self.mesh.replacePointSet(moved_points)

    # displacement relative to the recorded rest positions
    if self.getParameterValue("Show Displacement"):
        rest_points = self.initial_position[self.mesh.getName()]
        total_displacement = np.linalg.norm(moved_points - rest_points, axis=1)
        self.mesh.addPointData("Displacement", total_displacement)
    else:
        self.mesh.setDataRepresentationOff()

    # pick up any change of the update interval
    self.timer.setInterval(self.getParameterValue("Update Intervals"))
    
def init(self:camitk.Action):
    """Initialise the action state.

    No timer exists yet (``process`` creates it on first use), the apply
    button invites the user to start, and no rest positions are recorded.
    """
    self.timer = None
    self.setApplyButtonText("Start Shaking")
    # mesh name -> point positions recorded before shaking started
    self.initial_position = dict()

def process(self:camitk.Action):
    """Start or stop shaking (bound to the apply button).

    While the timer is running, a call stops it and puts the mesh that was
    being shaken (``self.mesh``) back to its recorded rest position, even if
    the selection changed in the meantime.

    Otherwise the last target becomes ``self.mesh``, its rest position is
    recorded the first time a mesh with that name is seen, the QTimer driving
    ``tick`` is created if needed, and the timer is started with the
    "Update Intervals" parameter.
    """
    if self.timer is not None and self.timer.isActive():
        # Stop first, then restore the mesh that was actually shaken rather
        # than whatever target is selected now.
        self.timer.stop()
        rest_points = self.initial_position[self.mesh.getName()]
        # Hand over a copy so the recorded rest state cannot be altered
        # through the mesh (whether replacePointSet keeps a reference is
        # not visible here).
        self.mesh.replacePointSet(np.copy(rest_points))
        self.setApplyButtonText("Restart Shaking")
        return

    self.mesh = self.getTargets()[-1]

    # store initial position (a private copy, in case getPointSetAsNumpy
    # returns a view on the mesh's own point storage — not visible here)
    mesh_name = self.mesh.getName()
    if mesh_name not in self.initial_position:
        self.initial_position[mesh_name] = np.copy(self.mesh.getPointSetAsNumpy())

    if self.timer is None:
        self.timer = QTimer()
        self.timer.timeout.connect(lambda: tick(self))

    self.setApplyButtonText("Stop Shaking")
    self.timer.start(self.getParameterValue("Update Intervals"))
    
def parameterChanged(self:camitk.Action, name:str):
    """Turn the displacement colouring off when "Show Displacement" is unchecked.

    Reacts to every parameter change (``name`` is ignored). Does nothing until
    ``process`` has chosen a mesh: ``self.mesh`` does not exist before that,
    and accessing it would raise ``AttributeError``.
    """
    mesh = getattr(self, "mesh", None)
    if mesh is None:
        return
    if not self.getParameterValue("Show Displacement"):
        mesh.setDataRepresentationOff()