# Run 'pip install matrix7' on the command line
# This dependency can be easily removed, since 2x2 matrix multiplication
# and transposition are the only operations needed
from matrix7 import Matrix, Vector
import decimal
from decimal import Decimal


####### THIS IS THE MAIN FUNCTION #######
def pol4pal(firstPart, middle=None):
    """Return the coefficients [a, b, c] of 'f(x) = ax^2 + bx + c' for the
    specified palindrome, or None if the palindrome is impossible.

    The palindrome is described by its first half, ``firstPart`` (a
    sequence of positive integers), and an optional central element,
    ``middle`` (None for a palindrome of even length).

    See http://andsol.org/2020/pol4pal.pdf

    Example usage:
        coeffs = pol4pal([1, 2], 7)
        if coeffs:
            f = polynomial(coeffs)
            print(f(2))
            print(f(10))
    """
    # Palindrome length
    length = 2*len(firstPart) + (0 if middle is None else 1)

    # Palindrome of length 0
    if length == 0:
        return [1, 0, 1]

    # Palindrome of length 1 (only the central element)
    if length == 1:
        if isEven(middle):
            return [middle**2 // 4, 1, 0]
        return [middle**2, 2, 0]

    # Particular case of length 3: [1, middle, 1] with middle odd.
    # length == 3 implies len(firstPart) == 1; indexing (rather than
    # comparing with the list [1]) also works for tuples/other sequences.
    if length == 3 and firstPart[0] == 1 and isOdd(middle):
        m = (middle + 1) // 2
        k = (2*m+1)**2
        return [k, 2*(k - 1), k - 2]

    # Other cases
    return pol4palGeneralCase(firstPart, middle)

def pol4palGeneralCase(firstPart, middle):
    """Coefficients [a, b, c] for palindromes not covered by pol4pal's
    special cases, or None if the palindrome is impossible.

    Builds R = M(firstPart[0]) * ... * M(firstPart[-1]), multiplied by
    one more factor for the central element when the length is odd. It
    then derives a 2x2 matrix S from R's first row via absBezout and
    reads the polynomial coefficients off S.

    Assumes firstPart is non-empty (pol4pal only calls this for
    palindromes of length >= 2), since firstPart[0] is read below.
    """
    # Used to modify S when the first palindrome term does not match
    Smod = Matrix([[1, 1],
                   [0, -1]])

    # Product of M(a) for a in firstPart, starting with the identity
    R = Matrix([[1, 0],
                [0, 1]])
    for element in firstPart:
        R = R * M(element)

    # Palindrome of even length
    if middle is None:
        # Two odd numbers in the first row of R
        # => palindrome impossible for n natural
        r, s = R[0]
        if isOdd(r) and isOdd(s):
            return None

        t, u = absBezout(2*r*s, r**2 - s**2)
        S = Matrix([[r**2 - s**2, t],
                    [2*r*s      , u]])

        # If the first term of the palindrome in cf(sqrt(t^2 + u^2)) is
        # not the first element of the input, adjust S
        term = sqrtCfPalindromeFirstTerm(t**2 + u**2)
        if term != firstPart[0]:
            S = S * Smod
        # S is rebuilt from S.raw before transpose(), presumably because
        # matrix7's transpose() mutates in place -- TODO confirm
        StS = Matrix(S.raw).transpose() * S
        # S^T * S is symmetric, so StS[1][0] == StS[0][1] (== s2)
        (s1, s2), (_, s3) = StS
        return [s1, 2*s2, s3]

    # Palindrome of odd length
    if isEven(middle):
        # Even central element: multiply R by the last M
        R = R * M(middle // 2)
        r, s = R[0]
        t, u = absBezout(s**2, r**2)
        S = Matrix([[r**2, t],
                    [s**2, u]])
        term = sqrtCfPalindromeFirstTerm(t*u)
    else:
        # Odd central element: multiply R by N
        R = R * N(middle)
        r, s = R[0]
        # Two even numbers in the first row => palindrome impossible
        if isEven(r) and isEven(s):
            return None
        t, u = absBezout(s**2, r**2)
        S = Matrix([[r**2, 2*t],
                    [s**2, 2*u]])
        term = sqrtCfPalindromeFirstTerm(2*t * 2*u)

    # If the first term of the palindrome in cf(sqrt(S[0][1] * S[1][1]))
    # is not the first element of the input, adjust S
    if term != firstPart[0]:
        S = S * Smod

    # f(x) = (s1*x + s2) * (s3*x + s4): product of the two rows of S * [x; 1]
    (s1, s2), (s3, s4) = S
    return [s1*s3, s1*s4 + s2*s3, s2*s4]


# First term of the palindrome in the continued fraction of sqrt(n)
def sqrtCfPalindromeFirstTerm(n):
    """Return a1, the first term after a0 in the continued fraction of sqrt(n).

    For a non-square n, sqrt(n) = [a0; a1, ..., a_{k-1}, 2*a0] with
    a1 ... a_{k-1} a palindrome, so a1 is the palindrome's first term.

    Computed with exact integer arithmetic: with a0 = floor(sqrt(n)) and
    d = n - a0**2 (a positive integer for non-square n),
        1 / (sqrt(n) - a0) = (sqrt(n) + a0) / d
    and floor((sqrt(n) + a0) / d) == (a0 + a0) // d.
    No Decimal precision has to be guessed, the result is exact for any
    size of n, and the global decimal context is left untouched.

    n: a non-negative integer.
    Raises ZeroDivisionError if n is a perfect square (sqrt(n) then has
    no periodic part), ValueError if n is negative.
    """
    a0 = _isqrt(n)
    d = n - a0 * a0
    if d == 0:
        raise ZeroDivisionError(
            f'{n} is a perfect square; sqrt({n}) has no periodic continued fraction')
    return (2 * a0) // d


def _isqrt(n):
    """Return floor(sqrt(n)) for a non-negative integer n (Newton's method)."""
    if n < 0:
        raise ValueError('square root of a negative number')
    if n == 0:
        return 0
    # Start from a power of two that is >= sqrt(n); the integer Newton
    # iteration then decreases monotonically down to floor(sqrt(n)).
    x = 1 << ((n.bit_length() + 1) // 2)
    while True:
        y = (x + n // x) // 2
        if y >= x:
            return x
        x = y


def M(n):
    """Return the 2x2 matrix [[n, 1], [1, 0]].

    pol4palGeneralCase multiplies one of these per element of the
    palindrome's first part.
    """
    top_row = [n, 1]
    bottom_row = [1, 0]
    return Matrix([top_row, bottom_row])

def N(n):
    """Return the 2x2 matrix [[n, 1], [2, 0]].

    Used by pol4palGeneralCase for an odd central element.
    """
    top_row = [n, 1]
    bottom_row = [2, 0]
    return Matrix([top_row, bottom_row])

# Bézout's identity
def bezout(a, b):
    """Return [t, u] with a*t + b*u == g, where g is the last nonzero
    remainder of Euclid's algorithm on (a, b), i.e. gcd(a, b) up to sign.

    Returns [1, 0] when b == 0.

    Iterative extended Euclidean algorithm. It yields exactly the same
    coefficients as a recursive formulation, since both compute the same
    product of quotient matrices. Unlike the recursive form, it cannot
    exceed Python's recursion limit when the remainder sequence is long
    (e.g. consecutive Fibonacci numbers with a few hundred digits).
    """
    prev_r, r = a, b
    prev_t, t = 1, 0
    prev_u, u = 0, 1
    while r != 0:
        # Python floor division: same quotient/remainder sequence as a % b
        q = prev_r // r
        prev_r, r = r, prev_r - q*r
        prev_t, t = t, prev_t - q*t
        prev_u, u = u, prev_u - q*u
    return [prev_t, prev_u]

# Absolute values of Bézout's identity, avoiding zeros
def absBezout(a, b):
    """Return [abs(t), abs(u)] for Bézout coefficients t, u of (a, b).

    If bezout() returns a zero coefficient, the pair is first shifted to
    (t + b, u - a). That pair gives the same value of a*t + b*u.
    """
    coefT, coefU = bezout(a, b)
    if not (coefT and coefU):
        coefT, coefU = coefT + b, coefU - a
    return [abs(coefT), abs(coefU)]

def isEven(n):
    """Return True when n is divisible by 2."""
    remainder = n % 2
    return remainder == 0


def isOdd(n):
    """Return True when n is not divisible by 2 (negation of isEven)."""
    return not isEven(n)

# Polynomial with the given coefficients
def polynomial(coeffs):
    """Return a callable evaluating a*x**2 + b*x + c for coeffs == [a, b, c]."""
    a, b, c = coeffs

    def evaluate(x):
        return a*x**2 + b*x + c

    return evaluate
