def fuss_catalan(k, n, prime_mod=0):
if prime_mod == 0:
return math.comb(n * k, n) // ((k - 1) * n + 1)
numer, denom = 1, (k - 1) * n + 1
for nu, de in zip(range(n * (k - 1) + 1, n * k + 1), range(1, n + 1)):
numer = numer * nu % prime_mod
denom = denom * de % prime_mod
return numer * pow(denom, -1, prime_mod) % prime_mod