#!/bin/bash

. /usr/sbin/so-common

# Select the output mode from the first argument only. Matching is by prefix:
# anything starting with -q or --quiet turns on quiet mode, anything starting
# with -v or --verbose turns on verbose mode; quiet wins because it is tested
# first. Any other value leaves both flags unset.
case $1 in
	-q* | --quiet*)
		quiet=true
		;;
	-v* | --verbose*)
		verbose=true
		;;
esac

# Live sshd configuration that the script rewrites once all changes validate.
sshd_config=/etc/ssh/sshd_config

# Scratch copy that is edited and validated before it replaces the live file.
# It is created with mktemp so the name cannot be predicted (or pre-created as
# a symlink) by another local user. The file is not removed automatically on
# failure, so a rejected config stays available for inspection.
temp_config=$(mktemp "${TMPDIR:-/tmp}/sshd_config.XXXXXX") || {
	echo "Unable to create a temporary file for the sshd config" >&2
	exit 1
}

# Effective "keyword value" line as reported by sshd (set by check_sshd_t)
before=
# The same line after unwanted algorithms have been stripped
after=
# Becomes true as soon as the scratch config has been modified
reload_required=false
# Ensures the verbose "Changes" banner is printed only once
change_header_printed=false

# Look up a setting in the effective sshd configuration.
# Arguments: $1 - text the wanted line of `sshd -T` output starts with
# Globals:   before (written) - receives every matching output line
check_sshd_t() {
	local keyword="$1" effective_line

	# The pattern is anchored at the start of the line only
	effective_line="$(sshd -T | grep "^${keyword}")"

	before="${effective_line}"
}

# Show how the current setting changed.
# Globals:   before, after (read) - the two values being compared
#            change_header_printed (read/written) - set to true once the
#            "Changes" banner has been printed
# Outputs:   on the first difference a "Changes" banner, then the diff body
#            followed by a blank line; nothing when the values do not differ
print_diff() {
	local changes

	# Both values are quoted so they are compared verbatim (no word splitting
	# or glob expansion). awk drops the first line of diff's output, which for
	# these single-line inputs is the change-command header (e.g. "1c1").
	changes=$(diff -dbB <(printf '%s\n' "$before") <(printf '%s\n' "$after") | awk 'NR>1')

	if [[ -n $changes ]]; then
		if [[ $change_header_printed == false ]]; then
			printf '%s\n' '' "Changes" '-------' ''
			change_header_printed=true
		fi
		# printf keeps backslashes in the diff literal, unlike `echo -e`
		printf '%s\n\n' "$changes"
	fi
}

# Make the scratch config carry exactly one line for a keyword.
# Arguments: $1 - sshd_config keyword (e.g. "ciphers")
#            $2 - complete "keyword value" line to write
# Globals:   temp_config (read) - file that is edited
#            reload_required (written) - set to true after a successful edit
# Returns:   0 on success, 1 when the file could not be edited
replace_or_add() {
	local type=$1
	local string=$2

	# Drop every existing directive for this keyword. sshd_config keywords are
	# case-insensitive and sshd keeps the first value it reads, so a leftover
	# "Ciphers ..." line would override the line appended below. The pattern is
	# anchored to the keyword position so comments and other directives that
	# merely contain the word are left alone. (The `I` address flag is GNU sed.)
	sed -i -E "/^[[:space:]]*${type}([[:space:]]|=)/Id" "$temp_config" || return 1

	# NOTE(review): the line is appended at the end of the file; if the config
	# ends with a Match block the directive would land inside it - confirm
	# whether that case has to be supported.
	printf '%s\n\n' "$string" >> "$temp_config" || return 1
	reload_required=true
}

# Validate the scratch config with `sshd -t`.
# Globals:   temp_config (read) - file handed to sshd for checking
# Outputs:   anything sshd reported, preceded by a heading, on stderr
# Returns:   the exit status of `sshd -t`
test_config() {
	local msg
	# sshd writes its diagnostics to stderr, so that stream has to be captured
	msg=$(sshd -t -f "$temp_config" 2>&1)
	local ret=$?

	if [[ -n $msg ]]; then
		echo "Error found in temp sshd config:" >&2
		# Quoted so multi-line diagnostics keep their line breaks
		printf '%s\n' "$msg" >&2
	fi

	return $ret
}

# Print a "keyword alg1,alg2,..." line with unwanted algorithms removed.
# Arguments: $1 - the line to filter; $2... - algorithm names to drop
# Outputs:   the filtered line on stdout (no trailing newline)
# Names are compared as whole list elements, so dropping
# "ecdsa-sha2-nistp256" never damages a longer name that merely contains it
# (e.g. "sk-ecdsa-sha2-nistp256-cert-v01@openssh.com").
_filter_algorithms() {
	local line=$1
	shift

	# No "keyword list" pair to work on: hand the input back unchanged
	if [[ $line != *" "* ]]; then
		printf '%s' "$line"
		return 0
	fi

	local keyword=${line%% *}
	local list=${line#* }
	local -a current=() kept=()
	local candidate unwanted drop

	IFS=',' read -r -a current <<< "$list"
	for candidate in "${current[@]}"; do
		[[ -n $candidate ]] || continue
		drop=false
		for unwanted in "$@"; do
			if [[ $candidate == "$unwanted" ]]; then
				drop=true
				break
			fi
		done
		if [[ $drop == false ]]; then kept+=("$candidate"); fi
	done

	# Re-join the survivors with commas
	local IFS=','
	printf '%s %s' "$keyword" "${kept[*]}"
}

# Harden one sshd setting: read its effective value, strip the unwanted
# algorithms and, if anything changed, write and validate the new line.
# Arguments: $1 - sshd keyword; $2... - algorithm names to remove
# Globals:   before, after (written); verbose (read)
# Exits with status 1 when the scratch config cannot be updated or fails
# validation.
_harden_setting() {
	local keyword=$1
	shift

	check_sshd_t "$keyword"
	after=$(_filter_algorithms "$before" "$@")

	if [[ $verbose ]]; then print_diff; fi

	if [[ $before != "$after" ]]; then
		replace_or_add "$keyword" "$after" || exit 1
		test_config || exit 1
	fi
}

# Copy the live sshd config to the scratch file, remove unwanted ciphers,
# key-exchange algorithms, MACs and host-key algorithms from it, then either
# install the result and reload sshd or discard the untouched copy.
# Globals:   sshd_config, temp_config, quiet, verbose, reload_required (read)
main() {
	if ! [[ $quiet ]]; then echo "Copying current config to $temp_config"; fi
	# Without a faithful copy the live config must never be replaced
	if ! cp -- "$sshd_config" "$temp_config"; then
		echo "Unable to copy $sshd_config to $temp_config" >&2
		exit 1
	fi

	# Add newline to ssh for legibility
	echo "" >> "$temp_config"

	# Ciphers
	local bad_ciphers=(
		"3des-cbc"
		"aes128-cbc"
		"aes192-cbc"
		"aes256-cbc"
		"arcfour"
		"arcfour128"
		"arcfour256"
		"blowfish-cbc"
		"cast128-cbc"
	)
	_harden_setting "ciphers" "${bad_ciphers[@]}"

	# KexAlgorithms
	local bad_kexalgs=(
		"diffie-hellman-group-exchange-sha1"
		"diffie-hellman-group-exchange-sha256"
		"diffie-hellman-group1-sha1"
		"diffie-hellman-group14-sha1"
		"ecdh-sha2-nistp256"
		"ecdh-sha2-nistp521"
		"ecdh-sha2-nistp384"
	)
	_harden_setting "kexalgorithms" "${bad_kexalgs[@]}"

	# Macs
	local bad_macs=(
		"hmac-sha2-512"
		"umac-128@openssh.com"
		"hmac-sha2-256"
		"umac-64@openssh.com"
		"hmac-sha1"
		"hmac-sha1-etm@openssh.com"
		"umac-64-etm@openssh.com"
	)
	_harden_setting "macs" "${bad_macs[@]}"

	# HostKeyAlgorithms: each name is removed both plain and in its
	# certificate form (name + "-cert-v01@openssh.com")
	local cert_suffix="-cert-v01@openssh.com"
	local bad_hostkeyalg_list=(
		"ecdsa-sha2-nistp256"
		"ecdsa-sha2-nistp384"
		"ecdsa-sha2-nistp521"
		"ssh-rsa"
		"ssh-dss"
	)
	local bad_hostkeyalgs=()
	local alg
	for alg in "${bad_hostkeyalg_list[@]}"; do
		bad_hostkeyalgs+=("$alg" "${alg}${cert_suffix}")
	done
	_harden_setting "hostkeyalgorithms" "${bad_hostkeyalgs[@]}"

	if [[ $reload_required == true ]]; then
		# Only reload when the validated config really is in place
		if ! mv -f -- "$temp_config" "$sshd_config"; then
			echo "Unable to replace $sshd_config with $temp_config" >&2
			exit 1
		fi
		if ! [[ $quiet ]]; then echo "Reloading sshd to load config changes"; fi
		systemctl reload sshd
		echo "[ WARNING ] Any new ssh sessions will need to remove and reaccept the host key fingerprint for this server before reconnecting."
	else
		if ! [[ $quiet ]]; then echo "No changes made to temp file, cleaning up"; fi
		rm -f -- "$temp_config"
	fi
}

# Run the hardening pass. No arguments are forwarded: the quiet/verbose flag
# was already read from $1 at the top of the script.
main
