#!/usr/bin/env python3

# requires python3-tk, python3-numpy
# Original script: https://raw.githubusercontent.com/reinderien/xcal/d0bcdd7c2dc02f8ba2ccf8ade14c2e202ca0c8fa/xcal

import re, os, sys
import numpy as np
from subprocess import run, PIPE, STDOUT

# xinput property whose value read_cal() parses into a 3x3 matrix and
# use_cal() overwrites.
prop_name = "Coordinate Transformation Matrix"
# X.org config snippet that save_cal() creates or updates (via sudo).
calibrate_file = "/usr/share/X11/xorg.conf.d/20-calibration.conf"

def xinput(*args):
    """Invoke /usr/bin/xinput with *args* and return its text output (see cmd)."""
    xinput_path = '/usr/bin/xinput'
    return cmd(xinput_path, *args)

def cmd(*args):
    """Run *args* as a command (no shell) and return its output as text.

    stderr is merged into stdout.  Because check=True, a non-zero exit status
    raises subprocess.CalledProcessError.
    """
    completed = run(
        args=args,
        env=os.environ,
        universal_newlines=True,
        check=True,
        stdout=PIPE,
        stderr=STDOUT,
    )
    return completed.stdout

def read_cal(dev):
    """Print and return a device's current calibration matrix and its inverse.

    dev: xinput device id or name (passed to xinput via str()).

    Returns (cal, cal_inv): the 3x3 np.matrix parsed from the device's
    prop_name property, and its matrix inverse.  Exits with status 1 when the
    property does not appear in ``xinput --list-props`` output.
    """
    stdout = xinput('--list-props', str(dev))
    # Escape the property name so it is matched literally, not as a pattern.
    line = re.search(re.escape(prop_name) + r'.*:\s+(\S.+)', stdout)
    if not line:
        # sys.exit rather than the site-module exit() helper, which is meant
        # for the interactive interpreter and is absent under `python -S`.
        print('Cal property not set; is this an xinput device?',
              file=sys.stderr)
        sys.exit(1)
    # The captured value must hold nine numbers; they are reshaped row-major.
    vals = np.matrix(line.group(1)).reshape(3, 3)

    print('Old calibration:')
    print(vals)
    print()
    return vals, np.linalg.inv(vals)

def ask(q):
    """Ask a yes/no question on the terminal; the default answer is yes.

    Returns True for an empty answer or for 'y' / 'yes' (case-insensitive,
    surrounding whitespace ignored), False for anything else.
    """
    answer = input(q + ' [y]: ').strip().lower()
    # The original accepted only a literal 'y', so typing 'yes' meant "no".
    return answer in ('', 'y', 'yes')

def transform(x, y, cal):
    """Map the point (x, y) through the 3x3 homogeneous matrix *cal*.

    Computes cal @ [x, y, 1]^T and returns its first two components as
    Python scalars; the third (homogeneous) component is dropped, not
    divided out.
    """
    column = np.array([[x], [y], [1]])
    mapped = np.asarray(cal) @ column
    return mapped.item(0), mapped.item(1)

def show_tk(n_points, old_cal_inv, new_cal=None):
    from tkinter import Tk, Canvas
    from math import ceil, sqrt

    root = Tk()
    X, Y = None, None
    root.attributes('-fullscreen', True)
    canvas = Canvas(root)

    def resize(event):
        nonlocal X, Y
        X, Y = event.width, event.height
        draw_legends()
        next_point()
    canvas.bind('<Configure>', resize)
    canvas.pack(expand=True, fill='both')

    legend_y = None

    def legend(text, colour='#000'):
        nonlocal legend_y
        canvas.create_text(X/2, legend_y, text=text, fill=colour)
        legend_y += 12

    def draw_legends():
        nonlocal legend_y
        legend_y = Y * 0.3
        if new_cal is not None:
            legend('TEST')
            legend('New cal point in green', '#0F0')
            legend('')

        legend('Esc to cancel')
        legend('Raw point in black')
        legend('Old cal point in blue', '#00F')
        legend('Target point in red', '#F00')

    point, points = {}, []
    index = -1
    n_cols = int(ceil(sqrt(n_points)))
    n_rows = int(ceil(n_points / n_cols))
    sensitive = False

    def next_point():
        nonlocal point, index, sensitive
        index += 1
        if index >= n_points:
            sensitive = False
            root.after(1000, root.destroy)
        else:
            sensitive = True
            x = 0.1 + 0.8*(index % n_cols)/(n_cols - 1)
            y = 0.1 + 0.8*(index // n_cols)/(n_rows - 1)
            point = {'sx': x, 'sy': y}

            draw_target(point['sx'], point['sy'])

    def cross(px, py, colour):
        x, y = px*X, py*Y
        canvas.create_line(x-10, y, x+10, y, fill=colour)
        canvas.create_line(x, y-10, x, y+10, fill=colour)

    def draw_target(px, py):
        x, y = px*X, py*Y
        canvas.create_oval(x-10, y-10, x+10, y+10, outline='#F00', width=3)
        cross(px, py, '#F00')

    def cancel_cal(_):
        print('Calibration cancelled')
        points.clear()
        root.destroy()
    root.bind('<Escape>', cancel_cal)
    canvas.bind('<Escape>', cancel_cal)

    def indicator(sx, sy, px, py, colour):
        canvas.create_line(X*sx, Y*sy, X*px, Y*py, fill=colour)
        cross(px, py, colour)

    def click(event):
        nonlocal sensitive
        if not sensitive:
            return
        sensitive = False
    
        sx, sy = point['sx'], point['sy']

        ox, oy = event.x/X, event.y/Y  # old-calibrated
        indicator(sx, sy, ox, oy, '#00F')

        ux, uy = transform(ox, oy, old_cal_inv)  # uncalibrated
        indicator(sx, sy, ux, uy, '#000')

        if new_cal is not None:
            nx, ny = transform(ux, uy, new_cal)  # new-calibrated (test only)
            indicator(sx, sy, nx, ny, '#0F0')

        point.update({'mx': ux, 'my': uy})
        points.append(point)

        canvas.after(500, next_point)
    canvas.bind('<Button-1>', click)

    root.mainloop()

    return points

def fit(screen_pts, mouse_pts):
    """Least-squares fit of an affine map taking mouse points to screen points.

    Solves m_mouse * m_transform ~= m_screen, where each row of m_mouse is
    [mx, my, 1] and each row of m_screen is [sx, sy, 1].

    Returns (m_transform, quality).  m_transform is the 3x3 least-squares
    solution.  quality is -log10 of the total squared residual (higher is
    better), or inf when the residual is exactly zero.
    """
    from math import inf, log10
    m_screen = np.matrix([[*p, 1] for p in screen_pts])
    m_mouse = np.matrix([[*p, 1] for p in mouse_pts])
    # rcond=None selects NumPy's machine-precision cutoff and avoids the
    # FutureWarning older NumPy versions emit when rcond is omitted.
    m_transform = np.linalg.lstsq(m_mouse, m_screen, rcond=None)[0]
    # lstsq's own residual array is empty when the system is not
    # over-determined (e.g. exactly 3 points) or is rank-deficient; summing
    # it then gave 0 and log10(0) raised.  Compute the residual directly.
    error = float(np.square(m_mouse * m_transform - m_screen).sum())
    quality = -log10(error) if error > 0 else inf
    return m_transform, quality

def calibrate(points):
    '''
    Fit a new calibration matrix to the clicked points.

    Solves, in the least-squares sense (see fit()):

    m_mouse * m_transform = m_screen
    [mx my 1] [a b 0]   [sx sy 1]
    [mx my 1] [c d 0] = [sx sy 1]
    [...    ] [e f 1]   [...    ]

    Each point dict supplies 'sx'/'sy' (target) and 'mx'/'my' (uncalibrated
    click).  The last column of the solution is forced to [0 0 1] and the
    matrix is returned transposed, i.e. in the form transform() applies to
    a column vector [x y 1]^T.

    Returns (calibration_matrix, quality), with quality taken from fit().
    '''
    targets = [(pt['sx'], pt['sy']) for pt in points]
    clicks = [(pt['mx'], pt['my']) for pt in points]
    m_transform, quality = fit(screen_pts=targets, mouse_pts=clicks)

    # Pin the affine column so no projective terms survive the transpose.
    m_transform[:, 2] = [[0], [0], [1]]
    return m_transform.T, quality

def get_devs():
    """Return {xinput id: device name} for every slave pointer device.

    Parses ``xinput --list --short`` output.  Exits with status 1 when no
    slave pointer device is listed.
    """
    listing = xinput('--list', '--short')
    devs = {int(dev_id): name for name, dev_id in
            re.findall(r'↳ (\w.+\w)\s+id=(\d+)\D+slave *pointer', listing)}
    if not devs:
        # sys.exit rather than the interactive-only site exit() helper.
        print('No suitable input devices found', file=sys.stderr)
        sys.exit(1)
    return devs

def print_devs(devs):
    """Print a table of pointer devices sorted by id, followed by a blank line."""
    lines = ['Pointer devices:', '%4s %35s' % ('ID', 'Name')]
    lines.extend('%4d %35s' % entry for entry in sorted(devs.items()))
    lines.append('')  # trailing blank line after the table
    print('\n'.join(lines))

def choose_preferred(devs):
    """Return the first device id whose name mentions 'touch' (any case).

    Falls back to the first id in *devs* when no name matches.
    """
    for dev_id, name in devs.items():
        if 'touch' in name.lower():
            return dev_id
    return next(iter(devs))

def choose_dev(devs, preferred):
    """Prompt until the user enters a valid device id; Enter picks *preferred*.

    Answers that are not integers, or not ids in *devs*, re-prompt silently.
    """
    prompt = 'Device to calibrate [%d]: ' % preferred
    while True:
        answer = input(prompt)
        if answer == '':
            return preferred
        try:
            candidate = int(answer)
        except ValueError:
            candidate = None
        if candidate is not None and candidate in devs:
            return candidate

def use_cal(device, new_cal):
    """Apply *new_cal* to *device* immediately via ``xinput --set-prop``.

    device: xinput device id or name (passed via str()).
    new_cal: 3x3 matrix (np.matrix or any array-like); its nine entries are
        sent in row-major order.
    """
    # np.asarray(...).ravel() flattens both np.matrix and plain arrays;
    # matrix.flatten().tolist()[0] only worked for np.matrix.
    values = np.asarray(new_cal).ravel().tolist()
    # Each value is passed with a trailing comma (e.g. '1.0,').
    cal_array = [str(x) + ',' for x in values]
    xinput('--set-prop', str(device), prop_name, *cal_array)

def save_cal(device, new_cal):
    """Persist *new_cal* in calibrate_file so it survives an X restart.

    device: value written to the Identifier and MatchProduct lines when a
        new file is created.  MatchProduct compares against the product
        name, so pass the device name rather than its numeric xinput id.
    new_cal: 3x3 matrix (np.matrix or any array-like), written row-major.

    If calibrate_file exists, only its TransformationMatrix value is
    rewritten in place with ``sudo sed -i``.  Otherwise a new InputClass
    section is installed with sudo.  Returns the output of the sudo command.
    """
    import tempfile

    values = np.asarray(new_cal).ravel().tolist()
    cal_string = ' '.join(str(x) for x in values)
    if os.path.exists(calibrate_file):
        return cmd('sudo', 'sed', '-i', 's/"TransformationMatrix" ".*"/"TransformationMatrix" "{0}"/g'.format(cal_string), calibrate_file)

    config = '''Section "InputClass"
    Identifier "{0} Calibration"
    MatchProduct "{0}"
    MatchDevicePath "/dev/input/event*"
    Driver "libinput"
    Option "TransformationMatrix" "{1}"
EndSection
'''.format(device, cal_string)

    # A private, unpredictably named temp file: the old fixed path
    # /tmp/calibration.conf could be pre-created or symlinked by another
    # user and then installed into the X config directory by root.
    fd, tmp_file = tempfile.mkstemp(prefix='calibration.', suffix='.conf')
    try:
        with os.fdopen(fd, 'w') as f:
            f.write(config)
        # install (instead of mv) creates a root-owned, world-readable copy
        # rather than leaving a user-owned file in the system directory.
        return cmd('sudo', 'install', '-m', '644', tmp_file, calibrate_file)
    finally:
        os.remove(tmp_file)

def main():
    """Interactive entry point: choose a device, calibrate, test, then save.

    An optional first command-line argument selects the device by its
    xinput name or numeric id; otherwise the user is prompted to choose.
    """
    from subprocess import CalledProcessError

    if 'DISPLAY' not in os.environ:
        # Probe xinput; if it cannot run without DISPLAY, assume :0.0.
        # Only command failures are caught, so Ctrl-C still aborts.
        try:
            xinput()
        except (CalledProcessError, OSError):
            os.environ['DISPLAY'] = ':0.0'

    device = None
    devs = get_devs()

    if len(sys.argv) > 1:
        arg = sys.argv[1].strip()
        if arg in devs.values():
            device = arg
        elif arg.isdecimal() and int(arg) in devs:
            device = int(arg)
        else:
            print("Input device not found: {0}".format(arg))

    if device is None:
        print_devs(devs)
        preferred = choose_preferred(devs)
        device = choose_dev(devs, preferred)

    # device is either a name (from argv) or an xinput id (id from argv or
    # prompt).  The saved config matches on the product name, so resolve it;
    # the interactive path previously wrote MatchProduct "<id>".
    device_name = devs.get(device, device)

    _, old_cal_inv = read_cal(device)

    n_points = 4
    print()

    print("Please press each point on the display")
    points = show_tk(n_points, old_cal_inv)
    if not points:
        # Cancelled with Esc: there is nothing to test or save.
        return

    new_cal, quality = calibrate(points)

    print('New calibration:')
    print(new_cal)
    print('Quality (should be at least 3): %.1f' % quality)
    print()

    print("Testing configuration...")
    show_tk(n_points, old_cal_inv, new_cal)

    if ask('Save calibration?'):
        use_cal(device, new_cal)
        save_cal(device_name, new_cal)

# Run the interactive calibration only when executed as a script, not when
# the module is imported.
if __name__ == '__main__':
    main()
