# Run 'pip install matrix7' on the command line.
# This dependency can be easily removed, since 2x2 matrix multiplication
# and transposition are the only operations needed.
import decimal
from decimal import Decimal

from matrix7 import Matrix, Vector


####### THIS IS THE MAIN FUNCTION #######

def pol4pal(firstPart, middle=None):
    """Return the coefficients [a, b, c] of 'f(x) = ax^2 + bx + c' for a palindrome.

    The palindrome is ``firstPart``, followed by ``middle`` (only for
    odd-length palindromes), followed by ``firstPart`` reversed; its length
    is ``2*len(firstPart)`` plus one if ``middle`` is given. Per the paper
    referenced below, this is the palindrome that appears in the continued
    fraction expansion of sqrt(f(x)).

    Returns None if the palindrome is impossible.

    See http://andsol.org/2020/pol4pal.pdf

    Example usage::

        coeffs = pol4pal([1, 2], 7)
        if coeffs:
            f = polynomial(coeffs)
            print(f(2))
            print(f(10))
    """
    # Palindrome length
    length = 2 * len(firstPart) + (0 if middle is None else 1)

    # Palindrome of length 0
    if length == 0:
        return [1, 0, 1]

    # Palindrome of length 1 (just the central element)
    if length == 1:
        if isEven(middle):
            return [middle**2 // 4, 1, 0]
        return [middle**2, 2, 0]

    # Particular case of length 3: [1, middle, 1] with an odd middle
    if length == 3 and firstPart == [1] and isOdd(middle):
        m = (middle + 1) // 2
        k = (2*m + 1)**2
        return [k, 2*(k - 1), k - 2]

    # Other cases
    return pol4palGeneralCase(firstPart, middle)


def pol4palGeneralCase(firstPart, middle):
    """Compute pol4pal's result for every case not handled directly there.

    Builds R = M(firstPart[0]) * M(firstPart[1]) * ..., derives a 2x2
    matrix S from R and obtains the polynomial from S, following the
    construction in http://andsol.org/2020/pol4pal.pdf.

    Returns [a, b, c], or None if the palindrome is impossible.
    ``firstPart`` must be non-empty (always true when called from pol4pal).
    """
    # Right-multiplied into S when the palindrome S produces does not
    # start with firstPart[0]
    Smod = Matrix([[1, 1], [0, -1]])

    # Product of M(a) for a in firstPart, starting from the identity
    R = Matrix([[1, 0], [0, 1]])
    for element in firstPart:
        R = R * M(element)

    if middle is None:
        # ---- Palindrome of even length ----
        r, s = R[0]
        # Two odd numbers in the first row of R
        # => palindrome impossible for n natural
        if isOdd(r) and isOdd(s):
            return None

        t, u = absBezout(2*r*s, r**2 - s**2)
        S = Matrix([[r**2 - s**2, t], [2*r*s, u]])

        # If the first element of the palindrome in cf(sqrt(t^2 + u^2))
        # is not the first element of the input, adjust S
        if sqrtCfPalindromeFirstTerm(t**2 + u**2) != firstPart[0]:
            S = S * Smod

        # f(x) = [x 1] * S^T * S * [x 1]^T.
        # S is copied before transposing -- presumably because matrix7's
        # transpose() works in place; TODO confirm.
        StS = Matrix(S.raw).transpose() * S
        # S^T * S is symmetric: StS[0][1] == StS[1][0]
        (s1, s2), (s2, s3) = StS
        return [s1, 2*s2, s3]

    # ---- Palindrome of odd length ----
    if isEven(middle):
        # Even central element: multiply R by one more M
        R = R * M(middle // 2)
        r, s = R[0]
        t, u = absBezout(s**2, r**2)
        S = Matrix([[r**2, t], [s**2, u]])
        term = sqrtCfPalindromeFirstTerm(t*u)
    else:
        # Odd central element: multiply R by N
        R = R * N(middle)
        r, s = R[0]
        if isEven(r) and isEven(s):
            return None
        t, u = absBezout(s**2, r**2)
        S = Matrix([[r**2, 2*t], [s**2, 2*u]])
        term = sqrtCfPalindromeFirstTerm(2*t * 2*u)

    # Is the first element of the palindrome the first element of the input?
    if term != firstPart[0]:
        S = S * Smod

    # f(x) is the product of the two rows of S * [x; 1]
    (s1, s2), (s3, s4) = S
    return [s1*s3, s1*s4 + s2*s3, s2*s4]


def sqrtCfPalindromeFirstTerm(n):
    """Return the first term of the palindrome in the continued fraction of sqrt(n).

    That term is floor(1 / frac(sqrt(n))), where frac is the fractional
    part. It is computed with Decimal arithmetic. The number of decimal
    places keeps growing until adding or subtracting one unit in the last
    decimal place of the fractional part no longer changes the result, so
    rounding cannot have affected it.

    ``n`` is expected to be a positive integer that is not a perfect square.
    For a perfect square the fractional part is 0 and ``1 / frac`` raises
    ``decimal.DivisionByZero`` (under the default decimal traps).

    The working precision is set on a local copy of the current decimal
    context, so the caller's decimal context is left unchanged.
    """
    # How many decimal places we want in the first attempt
    nextDp = len(str(n))
    # The precision must also include the number of digits of the
    # integer part of sqrt(n)
    sqrtIntegerPartLength = (nextDp + 1) // 2

    with decimal.localcontext() as ctx:
        while True:
            # Some extra decimal places to avoid rounding errors.
            # '+ 3' alone is not enough in every case
            # (e.g. pol4pal([104554254])), hence the check below.
            dp = nextDp + 3
            ctx.prec = sqrtIntegerPartLength + dp
            # One unit in the last decimal place, used for the check
            grain = Decimal(f'1e-{dp}')
            nextDp = 2 * nextDp

            # Square root of n and its fractional part
            sqrt = Decimal(n).sqrt()
            frac = sqrt - sqrt.to_integral_value(decimal.ROUND_DOWN)

            # First term of the palindrome in the continued fraction
            term = (1 / frac).to_integral_value(decimal.ROUND_DOWN)

            # If either perturbation changes the result, term may be wrong
            # due to rounding: retry with more decimal places
            termAbove = (1 / (frac + grain)).to_integral_value(decimal.ROUND_DOWN)
            termBelow = (1 / (frac - grain)).to_integral_value(decimal.ROUND_DOWN)
            if termAbove == term and termBelow == term:
                # Back to int
                return int(term)


def M(n):
    """Return the 2x2 matrix [[n, 1], [1, 0]] for the continued-fraction term n."""
    return Matrix([[n, 1], [1, 0]])


def N(n):
    """Return the 2x2 matrix [[n, 1], [2, 0]], used for an odd central element n."""
    return Matrix([[n, 1], [2, 0]])


def bezout(a, b):
    """Return [t, u] satisfying Bézout's identity a*t + b*u == ±gcd(a, b).

    Uses the recursive extended Euclidean algorithm. The result is exactly
    gcd(a, b) when a >= 0 and b >= 0. With negative inputs it may be
    -gcd(a, b), because Python's % takes the sign of the divisor.
    """
    if b == 0:
        return [1, 0]
    div, rem = divmod(a, b)
    t2, u2 = bezout(b, rem)
    return [u2, t2 - u2*div]


def absBezout(a, b):
    """Return the absolute values [|t|, |u|] of Bézout coefficients of a and b.

    If either coefficient returned by bezout() is zero, the pair is first
    shifted to (t + b, u - a). That pair still satisfies the same identity,
    and the shift is an attempt to avoid a zero. It does not guarantee a
    nonzero result in every case: for example absBezout(1, 1) == [1, 0].
    """
    t, u = bezout(a, b)
    if t == 0 or u == 0:
        t, u = t + b, u - a
    return [abs(t), abs(u)]


def isEven(n):
    """Return True if the integer n is even."""
    return n % 2 == 0


def isOdd(n):
    """Return True if the integer n is odd."""
    return not isEven(n)


def polynomial(coeffs):
    """Return the function x -> a*x**2 + b*x + c, where coeffs == [a, b, c]."""
    a, b, c = coeffs
    return lambda x: a*x**2 + b*x + c