Modular Arithmetic Review
Modular arithmetic is, simply put, arithmetic on the remainders of integer division. Formally, for integers \(a,b\) and a non-zero integer \(m\), \(a \equiv b \pmod{m}\) if and only if dividing \(a\) by \(m\) leaves the same remainder as dividing \(b\) by \(m\). So \(8 \equiv 5 \pmod{3}\), because 8 divided by 3 has remainder 2, and 5 divided by 3 also has remainder 2. Most operations are simple in modular arithmetic: addition, subtraction, and multiplication can be done normally, then the result is taken modulo \(m\) (reducing after each step also works and keeps the numbers small). Division is more involved because it requires finding an inverse: for integers \(a,m\), \(b\) is an inverse of \(a\) modulo \(m\) if \(ab \equiv 1 \pmod{m}\). Such an inverse exists exactly when \(\gcd(a,m) = 1\). It can be found by solving the linear Diophantine equation \(ab + my = 1\), or by using Euler's Theorem, which gives \(b = a^{\varphi(m)-1} \bmod m\). If the modulus is a prime \(p\), Euler's Theorem simplifies to Fermat's Little Theorem, and the inverse of ANY integer \(a\) not divisible by \(p\) is \(a^{p-2} \bmod p\).
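A minimal Python sketch of these rules follows. It assumes the prime modulus \(10^9+7\) purely for illustration, and the helper names congruent, inverse_fermat, and inverse_ext_gcd are hypothetical. inverse_ext_gcd takes the Diophantine route with the extended Euclidean algorithm, so it also works for composite moduli. inverse_fermat relies on the modulus being prime.

```python
def congruent(a: int, b: int, m: int) -> bool:
    """a ≡ b (mod m) exactly when a and b leave the same remainder mod m."""
    return a % m == b % m


def inverse_fermat(a: int, p: int) -> int:
    """Inverse of a modulo a prime p via Fermat's Little Theorem: a^(p-2) mod p.

    Assumes p is prime; a must not be a multiple of p.
    """
    if a % p == 0:
        raise ValueError("a multiple of p has no inverse modulo p")
    return pow(a, p - 2, p)


def inverse_ext_gcd(a: int, m: int) -> int:
    """Solve a*x + m*y = 1 (a linear Diophantine equation) with the extended
    Euclidean algorithm; x mod m is then the inverse of a modulo m."""
    old_r, r = a % m, m
    old_x, x = 1, 0
    # invariant: old_r ≡ old_x * a and r ≡ x * a (mod m)
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
    if old_r != 1:
        raise ValueError("no inverse: gcd(a, m) != 1")
    return old_x % m


P = 10**9 + 7  # hypothetical choice of prime modulus for this example

# Addition, subtraction, and multiplication: compute normally, then reduce.
x, y = 123456789, 987654321
total = (x + y) % P
diff = (x - y) % P  # Python's % already returns a value in [0, P)
prod = (x * y) % P

# Dividing by y means multiplying by the inverse of y.
quot = (x * inverse_fermat(y, P)) % P
```

For example, inverse_ext_gcd(3, 7) returns 5, since \(3 \cdot 5 = 15 \equiv 1 \pmod{7}\).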
Most competitive programmers use C++, where modular exponentiation is commonly implemented with Binary Exponentiation. Conveniently, Python already implements this in the built-in pow function, so by the above, pow(a,p-2,p) computes the modular inverse in \(O(\log p)\) modular multiplications, which is fast enough for most practical cases. Even if you forget Euler's Theorem, Python 3.8 and later can compute the modular inverse directly with pow(a,-1,p), which raises a ValueError when the inverse does not exist.
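Binary exponentiation is the technique the paragraph above refers to. The minimal sketch below illustrates the idea rather than reproducing CPython's internal code for pow, and the function name binpow is hypothetical. It reads the exponent bit by bit, squaring the base at each step and multiplying it into the result whenever the current bit is set.

```python
def binpow(a: int, e: int, m: int) -> int:
    """Compute a**e % m by binary exponentiation (square-and-multiply).

    Uses O(log e) multiplications; assumes e >= 0 and m >= 1.
    """
    result = 1 % m
    a %= m
    while e > 0:
        if e & 1:  # the lowest remaining bit of e is set
            result = result * a % m
        a = a * a % m  # square a so it matches the weight of the next bit
        e >>= 1
    return result


p = 10**9 + 7  # hypothetical prime modulus
a = 12345
inv = binpow(a, p - 2, p)  # Fermat: a^(p-2) is the inverse of a mod prime p
```

In practice the built-in pow(a, p - 2, p) gives the same value and is faster, since it runs in C.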
General Hashing Formula
We can then use modular arithmetic to hash strings efficiently. The first step is to assign (map) each character a distinct positive integer. A general choice is ord(c)+1, which converts a character to an integer from 1 to 128 based on its ASCII value. This setup is sufficient for most general strings made up of uppercase and lowercase letters, digits, and symbols. Most problems don't use all of these characters, so this part of the expression can vary slightly. For instance, if only lowercase letters are used, then ord(c)-96 is a better fit, since it maps the lowercase letters to the integers from 1 to 26. Using this mapping, the string “helloworld” becomes the array [8,5,12,12,15,23,15,18,12,4].
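Both mappings are one-line list comprehensions. In the minimal sketch below, the helper names map_ascii and map_lowercase are hypothetical, and map_lowercase assumes its input contains only the letters 'a' to 'z'.

```python
def map_ascii(s: str) -> list[int]:
    """General mapping: ord(c) + 1 sends ASCII characters (0-127) to 1-128."""
    return [ord(c) + 1 for c in s]


def map_lowercase(s: str) -> list[int]:
    """Lowercase-only mapping: ord(c) - 96 sends 'a'..'z' to 1..26."""
    return [ord(c) - 96 for c in s]


print(map_lowercase("helloworld"))  # [8, 5, 12, 12, 15, 23, 15, 18, 12, 4]
```

Starting the values at 1 rather than 0 matters. If 'a' were mapped to 0, then “a” and “aa” would both hash to 0 under the polynomial hash below.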
We can now use this array to precompute the hash values of ALL possible substrings in \(O(n^2)\) time (the table also takes \(O(n^2)\) memory), where \(n\) is the length of the array, using the following code:
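```python
# ar holds the mapped character values (here ord(c)-96 applied to "helloworld")
ar = [ord(c) - 96 for c in "helloworld"]
n = len(ar)
# create reference table: hashstrings[i][j] is the hash of ar[i..j]
hashstrings = [[0] * n for _ in range(n)]
p = 10**17 - 3  # large prime for avoiding hash collisions
b = 29  # base for exponents; larger than the largest character value (26)
for i in range(n):
    h = 0
    for j in range(i, n):
        h = (h * b + ar[j]) % p
        hashstrings[i][j] = h
```

This code works by induction: for the substring from index i to j, the hash value is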
\[ \left( \sum_{k=i}^{j} 29^{\,j-k} \cdot \text{ar}[k] \right) \bmod \left(10^{17}-3\right) \]
Put another way, the hash value of the substring is the value of its last character, plus 29 times the value of the second-to-last character, plus \(29^2\) times the value of the third-to-last character, and so on, with each earlier character weighted by one more factor of 29. The base case is the empty string, which has a hash value of 0. Multiplying by 29 then “shifts” every character in the string one place to the left, which leaves room for the new character's value to be added. This way, each new substring's hash is computed in \(O(1)\) from the hash of the same substring without its last character. Because every character value lies between 1 and 26, which is less than the base 29, distinct substrings give distinct sums before the modulus is applied. Collisions can therefore only come from the reduction modulo \(10^{17}-3\).
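The sketch below checks that the incremental loop really produces the closed-form sum by comparing the two on every substring of “helloworld”. The names hash_incremental and hash_direct are hypothetical, and the defaults reuse the base 29 and modulus \(10^{17}-3\) from the table code.

```python
def hash_incremental(ar: list[int], i: int, j: int,
                     b: int = 29, p: int = 10**17 - 3) -> int:
    """Hash of ar[i..j] (inclusive), built one character at a time as in the table code."""
    h = 0  # base case: the empty string hashes to 0
    for k in range(i, j + 1):
        h = (h * b + ar[k]) % p  # shift left one place, then add the new character
    return h


def hash_direct(ar: list[int], i: int, j: int,
                b: int = 29, p: int = 10**17 - 3) -> int:
    """Hash of ar[i..j] from the closed form: sum of b^(j-k) * ar[k] for k = i..j, mod p."""
    return sum(pow(b, j - k, p) * ar[k] for k in range(i, j + 1)) % p


ar = [ord(c) - 96 for c in "helloworld"]
print(hash_incremental(ar, 0, 1))  # "he": 29*8 + 5 = 237
```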
Example Problem
Codeforces Round 166 (Div. 2) Problem D
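This basic hashing method appears in the example problem. The problem gives a string s, a 26-character string marking each lowercase letter as good ("1") or bad ("0"), and an integer k. It asks for the number of distinct substrings of s that contain at most k bad characters. The code above can be slightly modified to compute the hash value of each substring that has at most k bad characters. Once the count of bad characters exceeds k, the inner loop can be broken, because every longer substring starting at the same index would contain too many bad characters as well. To count the distinct substrings, you can store the hash values in a set (a dictionary's keys work just as well) and then count its entries.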
```python
import sys

# input functions; strip() is safer than [:-1], which drops a real
# character if a line has no trailing newline
readint = lambda: int(sys.stdin.readline())
readin = lambda: sys.stdin.readline().strip()

s = readin()
goodbadstr = readin()
k = readint()

# determine bad characters: goodbadstr[i] == "0" marks chr(i + 97) as bad
badch = {chr(i + 97) for i in range(26) if goodbadstr[i] == "0"}

# compute hashes of every substring with at most k bad characters
m = 10**17 - 3
b = 29
hashvals = set()
n = len(s)
for i in range(n):
    badcount = 0
    h = 0
    for j in range(i, n):
        if s[j] in badch:
            badcount += 1
        if badcount > k:
            break  # too many bad characters; longer substrings from i are bad too
        h = (h * b + ord(s[j]) - 96) % m  # update hash
        hashvals.add(h)
print(len(hashvals))
```
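When a hash-based answer is wrong, it can be hard to tell whether the hashing or the counting logic is at fault. The minimal cross-check sketch below uses the hypothetical function names count_good_hashed and count_good_brute. The first wraps the same hashing loop in a function. The second stores the substrings themselves in a set, which is exact but costs far more time and memory, so it only suits small inputs. The random test generator and its parameters are also assumptions for illustration.

```python
import random


def count_good_hashed(s: str, goodbad: str, k: int) -> int:
    """Count distinct good substrings by storing their polynomial hashes."""
    bad = {chr(i + 97) for i in range(26) if goodbad[i] == "0"}
    m, b = 10**17 - 3, 29
    seen = set()
    for i in range(len(s)):
        badcount, h = 0, 0
        for j in range(i, len(s)):
            if s[j] in bad:
                badcount += 1
            if badcount > k:
                break
            h = (h * b + ord(s[j]) - 96) % m
            seen.add(h)
    return len(seen)


def count_good_brute(s: str, goodbad: str, k: int) -> int:
    """Count distinct good substrings exactly by storing the substrings themselves."""
    bad = {chr(i + 97) for i in range(26) if goodbad[i] == "0"}
    seen = set()
    for i in range(len(s)):
        for j in range(i + 1, len(s) + 1):
            sub = s[i:j]
            if sum(c in bad for c in sub) <= k:
                seen.add(sub)
    return len(seen)


def cross_check(trials: int = 200) -> None:
    """Compare both counters on small random inputs (a tiny alphabet forces repeats)."""
    rng = random.Random(0)
    for _ in range(trials):
        n = rng.randint(1, 12)
        s = "".join(rng.choice("abc") for _ in range(n))
        goodbad = "".join(rng.choice("01") for _ in range(26))
        k = rng.randint(0, n)
        assert count_good_hashed(s, goodbad, k) == count_good_brute(s, goodbad, k)
```

Consider s = "ababab" with only 'a' marked bad and k = 1. Both functions return 5, counting the substrings a, ab, b, ba, and bab.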