# -*- coding: utf-8 -*-
"""
A module to apply/compute a global correction to PPMXL proper motions
as described in 2010AJ....139.2440R.

The upstream version is at https://github.com/johnjvickers/PPMXL_Correction.

This version adds a healpix-based interpolation scheme to speed up
computation.  It is what actually produced the corrected proper motions.
Deviations from the exact values are negligible against the errors in
proper motion.

See the FastRecenterer class.
"""


import numpy as np
import pyfits
import os
from numpy.lib import recfunctions as rfs
from scipy.special import sph_harm
import healpy as hp
from math import modf
from scipy.interpolate import interp1d

import csv

# sudo apt-get install -y python-pip python-pyfits python-scipy python-numpy \
#                                  python-matplotlib pkg-config libcfitsio-dev
# pip install --user healpy

# Directory holding the fit coefficient tables (pmr.csv, pmd.csv).  It is
# located below the directory named by the GAVO_INPUTS environment variable;
# importing this module raises a KeyError when that variable is not set.
INSTALL_DIR = os.environ["GAVO_INPUTS"]+'/ppmxl/johnspms'

def fixpath(fname):
	"""returns the path of the file fname within INSTALL_DIR.
	"""
	resolved = os.path.join(INSTALL_DIR, fname)
	return resolved


#fits
# The complete fit tables, one row per magnitude shell.  Column 0 is the
# magnitude of the shell, the remaining columns are the fit coefficients.
ra_fit = np.genfromtxt(fixpath('pmr.csv'), delimiter=',')
de_fit = np.genfromtxt(fixpath('pmd.csv'), delimiter=',')

def read_fit(f_name):
	"""returns the fit coefficients stored in the CSV file f_name (a name
	relative to INSTALL_DIR) as a two-dimensional array.

	The rows are the magnitude shells.  The zeroth column of the file is
	the magnitude of the shell and is cut out, so only the coefficients
	remain.
	"""
	# np.genfromtxt is used rather than np.ndfromtxt, which is just a
	# deprecated alias for it; this also matches how ra_fit and de_fit
	# are read.
	return np.genfromtxt(fixpath(f_name), delimiter=',')[:, 1:]

# Coefficients only (no magnitude column): fit[0] is for pmr, fit[1] for pmd.
fit = (read_fit('pmr.csv'), read_fit('pmd.csv'))

#These files are 7 lines long, the first column is the magnitude slice it was
#fit to. After that the columns are leading coefficients for:
#sph(order, degree) = (0,0), (0,1), (1,1)_real, (1,1)_imag, (0,2), (1,2)_real,
#										(1,2)_imag, (2,2)_real, (2,2)_imag, (0,3)...
#we do not use negative orders as they add no new information
#the zeroth orders never have imaginary information
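#121 coefficients per shell (i.e., harmonics up to degree 10); with
#harm_degree = 8 only the first 81 of them are actually used.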

#These coefficients could change by submission.

#just in case I want to change how many harmonics are fit.
#This is the highest degree of the spherical harmonics that are summed up;
#all degrees from 0 through harm_degree are used.
harm_degree = 8


# conversion factors between degrees and radians
DEG2RAD = np.pi / 180.
RAD2DEG = 180 / np.pi


def surface_harmonics(harmonics, popt):
	"""returns the raw correction for a point, i.e., the sum of its
	spherical harmonics weighted with the fit coefficients popt.

	harmonics is a dictionary mapping (degree, order) to the value of the
	spherical harmonic at the point (as returned by get_harmonics).

	popt is the sequence of coefficients of one shell, in the order they
	are consumed here: for each degree from 0 through harm_degree and each
	order from 0 through that degree there is one coefficient for the real
	part, followed (for nonzero orders only) by one for the imaginary part.
	"""
	total = 0
	# index of the next coefficient to take from popt
	next_coeff = 0

	for degree in range(harm_degree + 1):
		# only non-negative orders enter the sum
		for order in range(degree + 1):
			harm_value = harmonics[degree, order]

			# every harmonic contributes its real part...
			total += popt[next_coeff] * harm_value.real
			next_coeff += 1

			# ...but there is no coefficient for the imaginary part of the
			# zeroth orders.
			if order != 0:
				total += popt[next_coeff] * harm_value.imag
				next_coeff += 1

	return total


def get_harmonics(ra, dec):
	"""returns a dictionary mapping (degree, order) to the spherical
	harmonics at the position ra, dec (both in degrees).

	The dictionary has all degrees from 0 through harm_degree, each with
	the orders 0 through degree.
	"""
	# sph_harm is handed the longitude and the colatitude in radians
	phi = ra*DEG2RAD
	theta = (90-dec)*DEG2RAD

	return {
		(degree, order): sph_harm(order, degree, phi, theta)
		for degree in range(harm_degree + 1)
		for order in range(degree + 1)}


def recenter(ra, dec, jmag):
	"""returns the corrections to PPMXL pmr and pmd based on a position
	in degrees and a J magnitude (in mag).

	The result is a pair (pmr_corr, pmd_corr).  The corrections are in
	mas/yr.

	For jmag at or below 14.05, the fit of the first shell is used, for
	jmag at or above 17.05 the fit of the last shell.  In between, the
	result is interpolated linearly between the two enclosing shells.

	A ValueError is raised when jmag cannot be compared against the
	shells (which is what happens for NaN).
	"""
	#the data used in the fits are in shells 0.1 mags wide and half a mag
	#apart [14.0-14.1, 14.5-14.6, ...]
	#that's the basis for these cuts and stuff

	harmonics = get_harmonics(ra, dec)

	def shell_corrections(shell):
		#find sum of all orders and degrees of sph harmonics at given point
		#multiplied by the fit coefficients for each order and degree

		#Note fit[shell][zeroth column is magnitude so we cut it out]
		return (surface_harmonics(harmonics, ra_fit[shell][1:]),
			surface_harmonics(harmonics, de_fit[shell][1:]))

	#fringe cases.  The limits themselves are included here: a jmag of
	#exactly 14.05 would otherwise match shell 0 in the loop below and be
	#"interpolated" against shell -1, i.e., the last one.
	if jmag <= 14.05: #assume first fit
		return shell_corrections(0)

	if jmag >= 17.05: #assume last fit
		return shell_corrections(-1)

	#start at shell 1 so that there always is a prior shell to interpolate
	#from
	for i in range(1, len(ra_fit)):
		#find the shells which enclose this particular j
		if ra_fit[i][0]+0.05 >= jmag:

			#linearly interpolate between the prior shell fit and the next
			#based on the j magnitude
			travel_frac = (jmag - (ra_fit[i-1][0] + 0.05)) / 0.5

			pmr_corr_down, pmd_corr_down = shell_corrections(i-1)
			pmr_corr_up, pmd_corr_up = shell_corrections(i)

			pmr_corr = pmr_corr_down + travel_frac*( pmr_corr_up - pmr_corr_down )
			pmd_corr = pmd_corr_down + travel_frac*( pmd_corr_up - pmd_corr_down )

			return (pmr_corr, pmd_corr)

	#only reached when all of the comparisons above are false (e.g., NaN);
	#fail with a clear message rather than silently returning None.
	raise ValueError("Cannot compute a correction for jmag=%r" % (jmag,))


def hp_harmonics(hp_angle):
	"""returns a dictionary mapping (degree, order) to the spherical
	harmonics for the position given as a Healpix angle.

	hp_angle[0] is handed to sph_harm in the place where get_harmonics
	passes the colatitude, hp_angle[1] where it passes the longitude.

	The dictionary has all degrees from 0 through harm_degree, each with
	the orders 0 through degree.
	"""
	theta = hp_angle[0]
	phi = hp_angle[1]

	return {
		(degree, order): sph_harm(order, degree, phi, theta)
		for degree in range(harm_degree + 1)
		for order in range(degree + 1)}


def recenter_bin(i, hp_angle):
	"""returns the corrections for the magnitude shell with the index i
	at a position given as a healpix angle.

	The result is a single complex number: the real part is the correction
	computed from the pmr.csv coefficients (fit[0]), the imaginary part
	the one computed from the pmd.csv coefficients (fit[1]).

	The corrections are in mas/yr.
	"""
	harmonics = hp_harmonics(hp_angle)

	# each part is the sum of all orders and degrees of the spherical
	# harmonics at the given point, weighted with the fit coefficients of
	# shell i
	pmr_corr = surface_harmonics(harmonics, fit[0][i])
	pmd_corr = surface_harmonics(harmonics, fit[1][i])

	return complex(pmr_corr, pmd_corr)


class FastRecenterer(object):
	"""A healpix-based cache of the corrections to speed up their
	computation.

	Construct with the healpix order (the default 6 should be fine given
	the precision of the whole thing); the constructor evaluates the
	corrections for every pixel and every magnitude shell, which takes a
	while.  Then call get_correction with ra, dec (in degrees) and jmag.
	Results are in mas/yr.
	"""
	def __init__(self, order = 6):
		nside = hp.order2nside(order)
		npix = hp.nside2npix(nside)
		bin_count = fit[0].shape[0]

		# self.hp_maps is indexed by (jmag_bin, healpix_index); the values
		# are the complex numbers returned by recenter_bin.
		self.hp_maps = np.ndarray((bin_count, npix), complex)
		for shell in range(bin_count):
			pixel_values = (recenter_bin(shell, hp.pix2ang(nside, pix))
				for pix in range(npix))
			self.hp_maps[shell] = np.fromiter(pixel_values, complex, npix)

	def get_correction(self, ra, dec, jmag):
		"""returns a pair of corrections for the position ra, dec (in
		degrees) and the magnitude jmag.

		The first item is the real part of the interpolated cache value,
		the second the imaginary part (see recenter_bin for what goes
		where).
		"""
		phi = ra * DEG2RAD
		theta = (90 - dec) * DEG2RAD

		# fringe cases: anything brighter than the first shell is treated
		# like the first fit, anything fainter than the last one like the
		# last fit.
		jmag = min(max(14.05, jmag), 17.049999999)

		# the data used in the fits are in shells 0.1 mags wide and half a mag
		# apart [14.0-14.1, 14.5-14.6, ...]
		# that's the basis for these cuts and stuff

		# the integral part is the index of the shell below jmag, the
		# fractional part says how far jmag is on the way to the next shell
		frac, lower = modf((jmag - 14.05) * 2)
		lower = int(lower)

		# cache values of the two enclosing shells at the position
		shell_values = [hp.get_interp_val(self.hp_maps[k], theta, phi)
			for k in (lower, lower + 1)]

		# linearly interpolate between the prior shell fit and the next
		# based on the j magnitude
		corr = interp1d([0, 1], shell_values)(frac)

		# split the complex result into (real, imaginary)
		return tuple(np.array([corr]).view(float))



def _test(fname):
	"""prints the average proper motions of the objects in the FITS table
	fname before and after applying the correction from recenter.

	The table is read with pyfits.getdata and needs the columns jmag,
	raj2000, dej2000, pmr_mas, and pmd_mas.  Rows with a NaN jmag are
	ignored.
	"""
	#run correction on a sample of ICRF defining radio sources

	data_pre = pyfits.getdata(fname)
	#get rid of faulty j measurements; the row counts before and after are
	#printed.  The prints are written as function calls like all the
	#others in here, which yields the same output on python 2.
	print(len(data_pre))
	data_pre = data_pre[~np.isnan(data_pre['jmag'])]
	print(len(data_pre))

	print(
		'Original Average P.M. in R.A.: ' + str(np.average(data_pre['pmr_mas'])) +
		' mas/yr.'
	)

	print(
		'Original Average P.M. in Dec.: ' + str(np.average(data_pre['pmd_mas'])) +
		' mas/yr.'
	)

	#the corrected proper motions [could be vectorized I guess]
	pmr_new = []
	pmd_new = []
	print('Recentering the data.')
	for star in data_pre:

		#correction for each
		pmr_corr, pmd_corr = recenter(
			star['raj2000'], star['dej2000'], star['jmag']
		)

		pmr_new.append(star['pmr_mas'] - pmr_corr)
		pmd_new.append(star['pmd_mas'] - pmd_corr)

	pmr_new = np.array(pmr_new)
	pmd_new = np.array(pmd_new)

	print(
		'Corrected Average P.M. in R.A.: ' + str(np.average(pmr_new)) + ' mas/yr.'
	)

	print(
		'Corrected Average P.M. in Dec.: ' + str(np.average(pmd_new)) + ' mas/yr.'
	)




def _test2():
	"""reads the table vc_qso.fit, appends columns with the corrected
	proper motions (pmr_corr_mas, pmd_corr_mas), and writes the result to
	vc_qso_recentered.fit, overwriting that file if it exists.

	The input table needs the columns ra, de, jmag, pmr_mas, and pmd_mas.
	"""
	infile = 'vc_qso.fit'
	outfile = 'vc_qso_recentered.fit'

	data_pre = pyfits.getdata(infile)

	# one (pmr, pmd) pair of corrected proper motions per row
	corrected = []
	for star in data_pre:
		pmr_corr, pmd_corr = recenter(star['ra'], star['de'], star['jmag'])
		corrected.append(
			(star['pmr_mas'] - pmr_corr, star['pmd_mas'] - pmd_corr))

	pmr_new = [pair[0] for pair in corrected]
	pmd_new = [pair[1] for pair in corrected]

	data_post = rfs.append_fields(
		data_pre,
		names=['pmr_corr_mas', 'pmd_corr_mas'],
		data=[pmr_new, pmd_new],
		dtypes=['>f4','>f4'],
		usemask=False)

	pyfits.writeto(outfile, data_post, clobber=True)


def csv_out(file_name, order, header, x):
	"""writes header and then the rows x as CSV to the file
	<file_name>_<order>.csv.

	header is a sequence of column names, x a sequence of rows (sequences
	of values).  An existing file is overwritten.
	"""
	import sys

	path = file_name + '_' + str(order) + '.csv'

	# the csv module wants a binary file on python 2 but a text file
	# opened with newline='' on python 3
	if sys.version_info[0] >= 3:
		my_file = open(path, 'w', newline='')
	else:
		my_file = open(path, 'wb')

	# the with block closes the file even when writing a row fails
	with my_file:
		wr = csv.writer(my_file)
		wr.writerow(header)
		wr.writerows(x)

def hp_recenter_test(order, test_size = 350000, constant_jmag = True):
	"""compares the healpix-interpolated corrections of a FastRecenterer
	of the given order against recenter at test_size random positions.

	With constant_jmag, all comparisons are made for jmag=15.5, otherwise
	jmag is drawn at random between 13 and 18.

	The results are written to hp_all_<order>.csv (relative deviations)
	and hp_log_all_<order>.csv (decadic logarithms of the relative
	deviations) in the current directory; in both files, these are
	followed by the corrections from recenter and from the cache.
	"""
	hp_cache = FastRecenterer(order)

	# collect relative differences of Healpix interpolation and recenter()
	nall = []
	lall = []
	for _ in range(test_size):
		args = (360 * np.random.rand(),
				90 - RAD2DEG * np.arccos(2 * np.random.rand() - 1),
				15.5 if constant_jmag else 13 + 5 * np.random.rand())
		x = np.array(recenter(*args))
		y = np.array(hp_cache.get_correction(*args))

		diff = abs((y - x) / x)
		corrections = x.tolist() + y.tolist()
		nall.append(list(args) + diff.tolist() + corrections)
		lall.append(list(args) + np.log10(diff).tolist() + corrections)

	n_header = ['RA', 'DEC', 'jmag', 'rel_dev_ra', 'rel_dev_dec',
				'corr_ra_johns', 'corr_dec_johns', 'corr_ra_hp', 'corr_dec_hp']
	# work on a copy; with a plain assignment, the slice assignment below
	# would also rename the columns in n_header, and the file with the
	# linear deviations would end up with the log_ column names.
	l_header = list(n_header)
	l_header[3:5] = ['log_rel_dev_ra', 'log_rel_dev_dec']
	csv_out('hp_all', order, n_header, nall)
	csv_out('hp_log_all', order, l_header, lall)



# When run as a script, compare the healpix interpolation against the
# direct computation; the alternatives below are kept commented out.
if __name__ == '__main__':

	hp_recenter_test(6) # takes about 45 minutes on an i5-750
	# hp_recenter_test(8)

	# # sample of Healpix interpolation as an alternative to recenter():
	# print recenter(2, 1, 16)
	# hp_cache = FastRecenterer()
	# print hp_cache.get_correction(2, 1, 16)

	# _test2()

	# print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
	# print('icrf vcs sources')
	# _test('icrf_vcs.fit')
	# print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')
	# print('icrf non vcs sources')
	# _test('icrf_non_vcs.fit')
