#!/usr/bin/env python3
## Display the end of an xshells energy.* file in a more readable way.

import numpy as np
import re
import sys
import subprocess
from shutil import get_terminal_size

# Usage message; printed when the script is started without a file argument,
# or with '--help' / '-h' as first argument.
HELP="""
xspeek prints the end of an energy.* file to the console, formatted in an easily readable way.
If gnuplot is available, it will also display a crude time-evolution of energy to the console.
If several files are given as arguments, it will display the quantities of each file at their end time.
Example:
    xspeek energy.bench
"""

def _tail_lines(f, data_start, nlines, chunk=2000):
    """Return the last `nlines` data lines of the binary file object `f`.

    Only the region after byte offset `data_start` is considered.  Blank lines
    and lines starting with '%' are discarded.  The file is read backwards in
    growing windows so that large files are not read entirely.
    """
    if nlines <= 0:
        return []
    f.seek(0, 2)
    end = f.tell()
    avail = end - data_start            # number of bytes holding data
    ofs = min(nlines * chunk, avail)    # size of the window read from the end
    while True:
        f.seek(end - ofs)
        lines = f.readlines()
        if ofs < avail:
            # the window starts somewhere inside a line: drop the truncated piece
            del lines[:1]
        lines = [l for l in lines if l.strip() and not l.startswith(b'%')]
        if len(lines) >= nlines or ofs >= avail:
            return lines[-nlines:]
        ofs = min(3 * ofs, avail)       # not enough lines yet: enlarge the window


def get_endlines(filename, nlines=5, firstlines=2):
    """Read the header, the first rows and the last rows of a text table.

    The first non-empty line of the file is taken as the header naming the
    columns (a leading '%' is stripped).  The following lines are rows of
    whitespace-separated numbers; blank lines and lines starting with '%'
    are ignored.

    Parameters:
        filename: path of the file to read.
        nlines: number of rows to take from the end of the file.
        firstlines: number of rows to take from the beginning of the data.

    Returns:
        (data, field): `data` is a numpy array made of the first rows followed
        by the last rows (both sets overlap when the file is short); `field`
        is the list of column names found in the header.

    Raises:
        ValueError: if the file has no header, or if a row cannot be converted
        to floating-point numbers.
    """
    with open(filename, 'rb') as f:
        header = f.readline()
        # skip empty lines, but stop at end of file instead of looping forever
        while header and len(header) < 3:
            header = f.readline()
        if not header:
            raise ValueError("no header found in '%s'" % filename)
        data_start = f.tell()           # the data begins right after the header

        rows = []
        for _ in range(firstlines):
            line = f.readline()
            while line.startswith(b'%') or (line and not line.strip()):
                line = f.readline()     # skip comments and blank lines
            if not line:                # end of file: fewer rows than requested
                break
            rows.append(line)

        rows.extend(_tail_lines(f, data_start, nlines))

    data = np.array([np.array(row.split(), dtype=float) for row in rows])

    # column names are separated by commas and/or whitespace; empty pieces
    # (e.g. produced by the end-of-line) are not names.
    names = header.decode().lstrip('% \n')
    field = [name for name in re.split(r'[,\s\t]+', names) if name]

    return data, field


def print_transpose(data, keys, prec=8):
    """Print one line per quantity: its name followed by its values.

    `data` is indexed as data[row, column] and `keys[i]` names column i.
    The values of a column are printed in reversed row order, with `prec`
    significant digits.  Columns containing only zeros are not printed.
    """
    nrows = data.shape[0]
    name_width = max([len(name) for name in keys], default=0)

    name_fmt = "%%%ds " % (name_width + 1)
    value_fmt = "%%%d.%dg " % (prec + 7, prec)
    line_fmt = name_fmt + value_fmt * nrows

    for col, name in enumerate(keys):
        column = data[:, col]
        if not any(column != 0.0):
            continue    # nothing but zeros: skip this quantity
        print(line_fmt % ((name,) + tuple(column[::-1])))


def print_normal(data, keys, prec=5, lnames=None):
    """Print `data` as a table: one row per line of `data`, one column per key.

    Columns containing only zeros are omitted.  When the table is wider than
    the terminal, it is split into several chunks of columns printed one
    below the other, separated by an empty line.

    Parameters:
        data: 2D array indexed as data[row, column].
        keys: names of the columns of `data`.
        prec: number of significant digits used to print the values.
        lnames: optional names of the rows, printed at the start of each row.
    """
    nlines = data.shape[0]

    # keep only the columns that hold at least one non-zero value
    ilist = [i for i in range(data.shape[1]) if any(data[:, i] != 0.0)]
    keys = [keys[i] for i in ilist]
    data = data[:, ilist]

    # width of a column: room for a number, or for the longest name
    lenk = max([prec + 6] + [len(k) for k in keys])

    # width of the row-name column (0 when rows are not named)
    if lnames:
        lenh = max(len(n) for n in lnames) + 1
    else:
        lnames = nlines * [""]
        lenh = 0

    # A printed line is lenh + ncol*(lenk+1) characters wide.  Always print at
    # least one column per chunk, even if the terminal is too narrow for it.
    ncol = max(1, (get_terminal_size().columns - lenh) // (lenk + 1))

    fmth0 = "%%-%ds" % lenh
    fmth = "%%-%ds " % lenk
    fmtd = "%%-%d.%dg " % (lenk, prec)
    for start in range(0, len(keys), ncol):
        if start > 0:
            print("")
        chunk = keys[start:start + ncol]
        nc = len(chunk)
        print((fmth0 + fmth * nc) % (("",) + tuple(chunk)))     # chunk of keys as header
        for l in range(nlines):
            print((fmth0 + fmtd * nc) % ((lnames[l],) + tuple(data[l, start:start + nc])))

def plot_txt(fname, keys, idx=(1,), logscale=True):
    """Draw a crude text-mode plot in the terminal using gnuplot.

    The first column of each file is used as abscissa.  With a single file,
    every column listed in `idx` is plotted.  With several files, only the
    first column of `idx` is plotted, once per file.

    Parameters:
        fname: a file name, or a list of file names.
        keys: column names, used as curve titles.
        idx: 0-based indices of the columns to plot; indices that have no
            matching entry in `keys` are ignored.
        logscale: use a logarithmic y axis if True.

    Nothing is plotted (and no error is raised) when gnuplot cannot be
    started or when there is nothing to plot.
    """
    if isinstance(fname, str):
        fname = [fname, ]      # fname must be a list
    idx = [i for i in idx if 0 <= i < len(keys)]
    if not fname or not idx:
        return

    def quoted(text):
        # inside a single-quoted gnuplot string, a quote is written as two quotes
        return text.replace("'", "''").encode()

    script = [
        # color output with: 'ansi' or 'ansi256'. b&w output with 'mono' or nothing.
        b"set term dumb size %d 28 ansi\n" % get_terminal_size().columns,
        b"set style data lines\n",
        b"set tics nomirror scale 0.3\n",
        b"set key above\n",
        b"set colorsequence classic\n",
    ]
    if logscale:
        script.append(b"set logscale y 10\n")

    if len(fname) > 1:      # same column of every file
        col = idx[0]
        curves = [b"'%s' using 1:%d title '%s:%s'"
                  % (quoted(name), col + 1, quoted(name), quoted(keys[col]))
                  for name in fname]
    else:                   # several columns of one file
        source = quoted(fname[0])
        curves = [b"'%s' using 1:%d title '%s'"
                  % (source if n == 0 else b"", col + 1, quoted(keys[col]))
                  for n, col in enumerate(idx)]     # '' means: same file as before
    script.append(b"plot " + b", ".join(curves) + b"\nquit\n")

    try:
        gnuplot = subprocess.Popen(["gnuplot"], stdin=subprocess.PIPE)
    except OSError:
        return      # gnuplot is not available: silently skip the plot
    # send all commands, close the pipe and wait for gnuplot to terminate
    gnuplot.communicate(b"".join(script))

def print_field_header(fname):
    """Print the information stored in the binary header of a field file.

    The file must start with a 1024-byte header made of 16 32-bit integers,
    8 doubles, a 64-byte string and a 832-byte string, followed by the radial
    grid (`nr` doubles).  The byte order is detected by checking that the
    header is self-consistent (nr >= 0 and nlm matching lmax, mmax and mres);
    the opposite byte order is tried if the native one does not fit.

    The decoded header is printed as a dictionary.

    Raises:
        ValueError: if the file is too short, or if the header is not
        consistent in either byte order.
    """
    import struct

    head_size = 1024
    head_pattern = '16i8d64s832s'
    # byte-order prefix opposite to the one of this machine
    swapped = '>' if (struct.pack('=h', 1) == struct.pack('<h', 1)) else '<'

    def consistent(h):
        # test nr, and nlm against lmax, mmax, mres (exact integer arithmetic)
        return h[5] >= 0 and (h[2] + 1) * (h[1] + 1) - (h[2] * h[3] * (h[2] + 1)) // 2 == h[4]

    with open(fname, "rb") as f:
        head = f.read(head_size)
        if len(head) < head_size:
            raise ValueError("'%s' is too short to hold a %d-byte field header" % (fname, head_size))
        h = struct.unpack('=' + head_pattern, head)     # try native order first
        bswap = not consistent(h)
        if bswap:
            h = struct.unpack(swapped + head_pattern, head)    # swap endian
            if not consistent(h):
                raise ValueError("'%s' does not start with a valid field header" % fname)
        nr, irs, ire = h[5], h[6], h[7]
        r = np.fromfile(f, dtype=float, count=nr)    # read radial grid
    if bswap:
        r = r.byteswap()

    dtyp = 'complex64' if (h[0] & 4096) else 'complex128'    # single or double precision
    if (h[0] & (4096*4)):  dtyp = 'fp48'  # fp 48 compression

    info = {'lmax':h[1], 'mmax':h[2], 'mres':h[3], 'nlm':h[4], 'nr':nr, 'ir':(irs,ire),
     # plain floats, so that the printed values do not depend on the numpy version
     'r':(float(r[irs]),float(r[ire])),
     'version':h[0]&4095, 'BC':(h[8],h[9]), 'ncomp':h[13], 'iter':h[14], 'step':h[15], 'time':h[19], 'shtnorm':h[12],
     'id':h[24].decode().strip('\x00'), 'dtype':dtyp, 'varltr':(h[0]&8192==8192), 'bswap':bswap, 'par':h[25].decode().strip('\x00')}
    print(info)

# ---- command-line entry point ----
# Usage: xspeek FILE [FILE ...]
if len(sys.argv) < 2 or sys.argv[1] in ('--help', '-h'):
    print(HELP)
    sys.exit()

# FIRST FILE
try:
    data, keys = get_endlines(sys.argv[1])
except Exception:
    # could not be read as a text table: display its binary field header instead
    print_field_header(sys.argv[1])
    sys.exit()

if data.ndim != 2 or data.shape[0] == 0:
    print("ERROR: no data found in '%s'" % sys.argv[1])
    sys.exit(1)

if len(sys.argv) == 2:
    # show the first two and the last three rows (or everything for short files)
    llist = [0,1,-3,-2,-1]
    if len(data) <= len(llist):
        llist = list(range(len(data)))
    # print_transpose(data[llist,:], keys) would give the same values, one quantity per line
    print_normal(data[llist], keys)
    plot_cols = [c for c in (1,2,3,4) if c < len(keys)]    # only columns that exist
    if plot_cols:
        plot_txt(sys.argv[1], keys, plot_cols)   # ascii-plot of energies

else:   # SEVERAL FILES (assuming same diags !!!)
    f_list = [sys.argv[1],]
    data = data[-1:,:]      # last row only, kept 2D so that a single file still works
    for f in sys.argv[2:]:
        if f in f_list:     # ignore files given twice
            continue
        try:
            data2, keys2 = get_endlines(f,1,0)
        except Exception as err:
            print("WARNING: skipping '%s' (%s)" % (f, err))
            continue
        if len(keys2) == len(keys):
            data = np.vstack((data, data2[-1,:]))
            f_list.append(f)
        else:
            print("WARNING: skipping '%s' with %d columns" % (f, len(keys2)) )

    print_normal(data, keys, lnames=f_list)
    ## ascii-plot of column 1 of all files, or of column 2 if column 1 is zero everywhere
    col = 1 if (data.shape[1] > 1 and any(data[:,1] != 0)) else 2
    if col < len(keys):
        plot_txt(f_list, keys, [col,])
