#!/usr/bin/python3
#
# Maximilian Wilhelm <max@rfc2324.org>
#  --  Mon 30 Mar 2020 11:55:47 PM CEST
#

import argparse
from dns.flags import to_text
from dns.resolver import Resolver
from ipaddress import ip_address
import sys
import time

# Exit code definitions; the same values double as per-zone check states
OK, WARNING, CRITICAL, UNKNOWN = range (4)

# Remember when we started so the total runtime can be reported at the end
time_start = time.time ()

# Command line interface
parser = argparse.ArgumentParser (description = 'Check DNS sync')
parser.add_argument (
	'--reference-ns',
	required = True,
	help = 'IP address of reference NS',
)
parser.add_argument (
	'--replica-ns',
	required = True,
	help = 'IP address of NS to be checked',
)
parser.add_argument (
	'--check-mode',
	choices = [ 'serial', 'axfr' ],
	default = 'serial',
	help = 'Compare only serial or full zone content?',
)
parser.add_argument (
	'--timeout',
	type = int,
	default = 10,
	help = 'Timeout for DNS operations',
)
parser.add_argument (
	'--verbose', '-v',
	action = 'store_true',
	help = 'Be verbose in the output',
)
parser.add_argument (
	'zones',
	nargs = '+',
	help = 'Zones to compare',
)

args = parser.parse_args ()

# Only the serial comparison is implemented; bail out early for AXFR mode
if args.check_mode == 'axfr':
	print ("AXFR check mode not implemented yet. Send patches :)")
	sys.exit (UNKNOWN)

#
# Helpers
#

def is_ip (ns):
	"""Return True if *ns* is a valid IPv4 or IPv6 address, False otherwise.

	The value is handed to ipaddress.ip_address (); a ValueError raised by
	it means *ns* is not an IP address (e.g. it is a host name).
	"""
	try:
		# Only validity matters here, the parsed address object is not needed
		ip_address (ns)
	except ValueError:
		return False
	else:
		return True


def _query_soa (resolver, zone, role, ns_ip):
	"""Query *resolver* for the SOA record of *zone* and require an authoritative answer.

	*role* ('reference' or 'replica') and *ns_ip* are only used to build
	the error messages.

	Returns a tuple (answer, state, error):
	  * success:                  (answer object, OK, '')
	  * answer without AA flag:   (None, CRITICAL, message)
	  * any exception on the way: (None, UNKNOWN, message)
	"""
	try:
		answer = resolver.query (zone, 'SOA')

		# Check if the answer is authoritative
		if 'AA' not in to_text (answer.response.flags):
			return None, CRITICAL, "Got non-authoritive answer from %s NS: %s" % (role, ns_ip)
	except Exception as e:
		# Deliberately broad: whatever goes wrong while talking to the name
		# server is reported as a check result instead of a traceback.
		return None, UNKNOWN, "Error while checking %s NS %s: %s" % (role, ns_ip, e)

	return answer, OK, ''


def check_zone (zone):
	"""Compare the SOA serial of *zone* on the reference and the replica NS.

	Uses the module level resolvers reference_res / replica_res and the
	parsed command line arguments in args.

	Returns a dict with the keys
	  'state'  - OK if both serials match, CRITICAL on a serial mismatch or
	             a non-authoritative answer, UNKNOWN if a query failed or
	             the serial could not be read,
	  'diff'   - always '' (not filled by the serial check),
	  'errors' - human readable description of the problem, '' if none.
	"""
	res = {
		'state' : UNKNOWN,
		'diff' : '',
		'errors' : '',
	}

	# Only the serial comparison is implemented here
	if args.check_mode != 'serial':
		return res

	# Query reference NS first, then the replica; stop at the first failure
	answers = {}
	name_servers = (
		('reference', reference_res, args.reference_ns),
		('replica', replica_res, args.replica_ns),
	)
	for role, resolver, ns_ip in name_servers:
		answer, state, error = _query_soa (resolver, zone, role, ns_ip)
		if answer is None:
			res['state'] = state
			res['errors'] = error
			return res

		answers[role] = answer

	try:
		# Read the serial from the parsed SOA rdata rather than from a fixed
		# position in the textual form of the first answer RRset, which need
		# not be the SOA (e.g. when the name is a CNAME).
		reference_serial = answers['reference'].rrset[0].serial
		replica_serial = answers['replica'].rrset[0].serial
	except (AttributeError, IndexError, TypeError) as e:
		# State stays UNKNOWN; always store a string in 'errors'
		res['errors'] = str (e)
		return res

	if reference_serial == replica_serial:
		res['state'] = OK
	else:
		res['state'] = CRITICAL
		res['errors'] = "Serial mismatch: %s vs. %s" % (reference_serial, replica_serial)

	return res


#
# Setup
#

# Check for possible badness: both name servers have to be given as IP
# addresses, as they are used directly as resolver targets.
if not is_ip (args.reference_ns):
	print ("Error: Reference NS has to an IP address.")
	sys.exit (CRITICAL)

if not is_ip (args.replica_ns):
	print ("Error: Replica NS has to an IP address.")
	sys.exit (CRITICAL)

# Compare the parsed addresses instead of the raw strings, so that different
# spellings of the same address (e.g. '::1' and '0:0:0:0:0:0:0:1') are
# recognized as identical, too.
if ip_address (args.reference_ns) == ip_address (args.replica_ns):
	print ("Error: Reference NS and replica NS must not be the same!")
	sys.exit (CRITICAL)


# One dedicated resolver per name server: reference NS and NS to be checked
reference_res = Resolver (configure = False)
replica_res = Resolver (configure = False)

# Point each resolver at exactly one server and apply the configured timeout
for _resolver, _ns_ip in (
	(reference_res, args.reference_ns),
	(replica_res, args.replica_ns),
):
	_resolver.nameservers = [_ns_ip]
	_resolver.lifetime = args.timeout


#
# Let's go
#

codes = {}		# state -> number of zones in that state
ret_code = OK		# highest state seen so far, used as exit code
errors = []		# one message line per zone which is not OK
in_sync = []		# zones with state OK

for zone in args.zones:
	check = check_zone (zone)

	# Keep track of states
	state = check['state']
	codes[state] = codes.get (state, 0) + 1

	if state != OK:
		errors.append ("Zone '%s': %s\n" % (zone, check['errors']))
		ret_code = max (ret_code, state)
	else:
		in_sync.append (zone)

# Problems first ...
if errors:
	print ("".join (errors))

# ... then the healthy zones, if requested
if in_sync and args.verbose:
	print ("Zones in sync: %s" % ", ".join (sorted (in_sync)))

# Runtime since script start in milliseconds
time_delta = int (1000 * (time.time () - time_start))

# Summary line with the number of zones per state
state_counts = tuple (codes.get (code, 0) for code in (OK, WARNING, CRITICAL, UNKNOWN))
summary = "Checked %d zones in %d ms. %d OK, %d WARN, %d CRIT, %d UNKN"
print (summary % ((len (args.zones), time_delta) + state_counts))

sys.exit (ret_code)