Bitwise Recursion
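Define the function \(f:\mathbb{N}_0 \times \mathbb{N}_0 \to \mathbb{N}_0\) by \(f(n, k) = n\) if \(k = 0\) and \(f(n, k) = f(n \oplus k, \lfloor k/2 \rfloor)\) if \(k > 0\), where \(\oplus\) denotes bitwise XOR. Let \(S(N) = \sum_{k=0}^{N} \sum_{n=0}^{N} f(n, k)\). Find \(S(10^{18})\) mod (10^9...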
Problem Statement
Let \(b(n)\) be the largest power of \(2\) that divides \(n\). For example \(b(24) = 8\).
Define the recursive function: \begin {align*} A(0) &= 1 \\ A(2n) &= 3A(n) + 5A(2n - b(n)) \qquad \text {for } n > 0 \\ A(2n + 1) &= A(n) \end {align*}
and let \(H(t, r) = A\left ((2^t + 1)^r\right )\).
You are given \(H(3, 2) = A(81) = 636056\).
Find \(H(10^{14} + 31, 62)\). Give your answer modulo \(\num {1000062031}\).
Problem 811: Bitwise Recursion
Mathematical Foundation
Let \(\oplus\) denote bitwise XOR and write any non-negative integer \(k\) in binary as \(k = \sum_{i \ge 0} k_i 2^i\) with \(k_i \in \{0, 1\}\).
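Lemma 1 (Closed form for \(f\)). For all \(n, k \ge 0\), \(f(n, k) = n \oplus g(k)\), where \(g(k) = \bigoplus_{i \ge 0} \lfloor k/2^i \rfloor\).
Proof. Define \(k^{(i)} = \lfloor k/2^i \rfloor\). The recursion unfolds as:
- Step 0: accumulator \(n\), remaining key \(k^{(0)} = k\).
- Step 1: accumulator \(n \oplus k^{(0)}\), remaining key \(k^{(1)}\).
- Step \(i\): accumulator \(n \oplus k^{(0)} \oplus \cdots \oplus k^{(i-1)}\), remaining key \(k^{(i)}\).
The recursion terminates when \(k^{(i)} = 0\), i.e., when \(2^i > k\). The final accumulator is \(n \oplus \bigoplus_{i \ge 0} k^{(i)} = n \oplus g(k)\).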
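The unfolding above can be checked numerically. The following is a minimal sketch, not the page's own implementation: `f` transcribes the recursion as a loop, `g` accumulates the XOR of all right-shifts, and the two are compared on a small grid.

```python
def f(n: int, k: int) -> int:
    """Recursion written as a loop: f(n, 0) = n, f(n, k) = f(n ^ k, k >> 1)."""
    while k > 0:
        n, k = n ^ k, k >> 1
    return n


def g(k: int) -> int:
    """XOR of k, k >> 1, k >> 2, ... until the shifted value reaches zero."""
    acc = 0
    while k > 0:
        acc ^= k
        k >>= 1
    return acc


# The closed form f(n, k) = n XOR g(k) holds on a small grid.
assert all(f(n, k) == n ^ g(k) for n in range(64) for k in range(64))
print(f(5, 6), f(5, 4))  # prints "1 2"
```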
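Lemma 2 (Suffix XOR characterisation of \(g\)). Bit \(j\) of \(g(k)\) equals the suffix XOR parity of the binary representation of \(k\) starting at position \(j\): \(g(k)_j = \bigoplus_{i \ge j} k_i\).
Proof. Bit \(j\) of \(\lfloor k/2^i \rfloor\) is \(k_{i+j}\). Therefore \(g(k)_j = \bigoplus_{i \ge 0} k_{i+j} = \bigoplus_{i \ge j} k_i\).
This is a finite sum since \(k_i = 0\) for \(2^i > k\).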
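As an illustration of the suffix-parity statement, the sketch below compares each bit of `g(k)` with the parity of the bits of `k` at positions `j` and above. The helper name `suffix_parity` is introduced here for illustration only.

```python
def g(k: int) -> int:
    """XOR of all right-shifts of k."""
    acc = 0
    while k > 0:
        acc ^= k
        k >>= 1
    return acc


def suffix_parity(k: int, j: int) -> int:
    """Parity of the bits of k at positions j, j+1, j+2, ... (hypothetical helper)."""
    return bin(k >> j).count("1") & 1


# Bit j of g(k) agrees with the suffix parity for every small k and j.
for k in range(256):
    for j in range(9):
        assert (g(k) >> j) & 1 == suffix_parity(k, j)
```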
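Theorem 1 (Invertibility). The function \(g\) is a bijection on \(\mathbb{N}_0\) with inverse \(m \mapsto m \oplus \lfloor m/2 \rfloor\), i.e., \(g(k) \oplus \lfloor g(k)/2 \rfloor = k\) for all \(k \ge 0\). It is not an involution: \(g(g(4)) = g(7) = 5\).
Proof. Over \(\mathbb{F}_2\), represent the binary digits of \(k\) as a column vector \(\mathbf{k} = (k_0, k_1, \dots)^T\). By Lemma 2 the map \(g\) acts as \(\mathbf{k} \mapsto U\mathbf{k}\), where \(U\) is the upper-triangular all-ones matrix with \(U_{ij} = 1\) iff \(j \ge i\). The \((i,j)\)-entry of \(U^2\) is \(\sum_{l} U_{il} U_{lj} = \#\{l : i \le l \le j\} = j - i + 1\) when \(j \ge i\), which is 1 over \(\mathbb{F}_2\) iff \(j - i\) is even, so \(U^2 \ne I\).
The inverse is nevertheless explicit. By Lemma 2, \(g(k)_j \oplus g(k)_{j+1} = k_j\), so \(U^{-1} = I + T\) over \(\mathbb{F}_2\), where \(T_{ij} = 1\) iff \(j = i + 1\) is the one-step shift. Therefore \(g(k) \oplus \lfloor g(k)/2 \rfloor = k\).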
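A quick numerical look at how `g` composes with itself and with the map `m -> m ^ (m >> 1)`. The helper name `gray` is an assumption introduced here (the map is the standard binary-reflected Gray code); it does not appear in the original solution.

```python
def g(k: int) -> int:
    """XOR of all right-shifts of k."""
    acc = 0
    while k > 0:
        acc ^= k
        k >>= 1
    return acc


def gray(m: int) -> int:
    """m XOR (m >> 1): XOR of each bit with its higher neighbour (hypothetical helper)."""
    return m ^ (m >> 1)


# gray undoes g, and g undoes gray, on a small range.
assert all(gray(g(k)) == k for k in range(1024))
assert all(g(gray(k)) == k for k in range(1024))

# Applying g twice does not return the input in general.
print(g(4), g(g(4)))  # prints "7 5"
```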
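Theorem 2 (Decomposition of \(S(N)\)). Write \(M = N + 1\). Then \(S(N) = \sum_{j \ge 0} 2^j c_j\),
where \(c_j = A_j (M - B_j) + (M - A_j) B_j\), with \(A_j = \#\{n \in [0, N] : n_j = 1\}\) and \(B_j = \#\{k \in [0, N] : g(k)_j = 1\}\).
Proof. We have \(S(N) = \sum_{k=0}^{N} \sum_{n=0}^{N} (n \oplus g(k)) = \sum_{j \ge 0} 2^j \, \#\{(n, k) \in [0, N]^2 : (n \oplus g(k))_j = 1\}\).
Bit \(j\) of \(n \oplus g(k)\) is 1 iff \(n_j \ne g(k)_j\). Counting over independent choices of \(n\) and \(k\): \(c_j = A_j (M - B_j) + (M - A_j) B_j\).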
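The per-bit decomposition can be sanity-checked against the double sum for small N. This is a sketch under the assumption that `A` counts n in [0, N] with bit j set and `B` counts k in [0, N] with bit j of g(k) set. Both counts are obtained by plain enumeration here, so it only illustrates the identity, not the fast method.

```python
def g(k: int) -> int:
    """XOR of all right-shifts of k."""
    acc = 0
    while k > 0:
        acc ^= k
        k >>= 1
    return acc


def S_direct(N: int) -> int:
    """Double sum of n XOR g(k) over 0 <= n, k <= N."""
    return sum(n ^ g(k) for k in range(N + 1) for n in range(N + 1))


def S_by_bits(N: int) -> int:
    """Same sum, assembled bit by bit from the two marginal counts."""
    M = N + 1
    total = 0
    for j in range(max(1, N.bit_length())):
        A = sum((n >> j) & 1 for n in range(M))      # n with bit j set
        B = sum((g(k) >> j) & 1 for k in range(M))   # k with bit j of g(k) set
        total += (A * (M - B) + (M - A) * B) << j    # pairs whose bits differ, times 2^j
    return total


assert all(S_direct(N) == S_by_bits(N) for N in range(40))
```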
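Lemma 3 (Computing \(A_j\)). The count \(A_j\) is given by standard bit-counting: \(A_j = \lfloor (N+1)/2^{j+1} \rfloor \cdot 2^j + \max\bigl(0, ((N+1) \bmod 2^{j+1}) - 2^j\bigr)\).
Proof. Among \(0, 1, \dots, N\), bit \(j\) cycles with period \(2^{j+1}\): it is 0 for \(2^j\) values then 1 for \(2^j\) values. Counting complete cycles plus the partial remainder yields the formula.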
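The cycle-counting argument translates into a few lines. The sketch below mirrors the `count_bit_set` routine used in the code section and compares it with direct enumeration for small inputs.

```python
def count_bit_set(N: int, j: int) -> int:
    """Number of n in [0, N] whose bit j is 1."""
    full, rem = divmod(N + 1, 1 << (j + 1))  # complete periods and leftover values
    # Each full period contributes 2^j ones; the leftover contributes
    # whatever extends past the first 2^j (zero-bit) values.
    return full * (1 << j) + max(0, rem - (1 << j))


for N in range(200):
    for j in range(9):
        assert count_bit_set(N, j) == sum((n >> j) & 1 for n in range(N + 1))
```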
Lemma 4 (Computing \(B_j\) via digit DP). The count \(B_j\) can be computed using a digit dynamic programming approach that processes the bits of \(k\) from MSB to LSB, tracking: (i) whether the prefix of \(k\) is still tight to \(N\), and (ii) the running suffix XOR parity from bit \(j\) onward. This requires \(O(1)\) states per bit position \(j\).
Proof. The digit DP is a standard technique for counting integers in \([0, N]\) satisfying a bitwise predicate. For each bit \(j\), the suffix XOR parity involves bits at positions \(\ge j\), which are determined as we process from MSB down. The tight/free flag doubles the state count, giving at most 4 states per bit and \(O(\log^2 N)\) total work across all \(j\).
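A compact variant of the digit DP, using a fixed 2x2 table indexed by `[tight][parity]` and cross-checked against enumeration. The function name `count_g_bit` is hypothetical, and this is one possible arrangement of the states described above rather than a definitive implementation.

```python
def g(k: int) -> int:
    """XOR of all right-shifts of k."""
    acc = 0
    while k > 0:
        acc ^= k
        k >>= 1
    return acc


def count_g_bit(N: int, j: int) -> int:
    """Number of k in [0, N] whose bits at positions >= j have odd parity."""
    dp = [[0, 0], [1, 0]]  # dp[tight][parity]; start tight with even parity
    for pos in range(N.bit_length() - 1, -1, -1):
        nbit = (N >> pos) & 1
        new = [[0, 0], [0, 0]]
        for tight in (0, 1):
            for par in (0, 1):
                cnt = dp[tight][par]
                if cnt == 0:
                    continue
                max_d = nbit if tight else 1  # a tight prefix may not exceed N's bit
                for d in range(max_d + 1):
                    new_tight = 1 if (tight and d == nbit) else 0
                    new_par = par ^ d if pos >= j else par  # only bits >= j count
                    new[new_tight][new_par] += cnt
        dp = new
    return dp[0][1] + dp[1][1]


for N in range(200):
    for j in range(9):
        assert count_g_bit(N, j) == sum((g(k) >> j) & 1 for k in range(N + 1))
```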
Editorial
The recursion is f(n, k) = n if k = 0 and f(n, k) = f(n XOR k, k >> 1) if k > 0. Key insight: f(n, k) = n XOR g(k), where g(k) is the XOR of all right-shifts of k. Bit j of g(k) is the XOR of bits j, j+1, j+2, … of k (the suffix XOR parity). Hence S(N) = sum_{k=0}^{N} sum_{n=0}^{N} f(n, k) = sum_{k=0}^{N} sum_{n=0}^{N} (n XOR g(k)), taken mod (10^9 + 7). For each bit j we compute A_j, the count of n in [0,N] with bit j set, in closed form, and B_j, the count of k in [0,N] with bit j of g(k) set, via a digit DP on k that tracks the suffix XOR parity from bit j onward. The two counts combine into the number of pairs (n, k) whose XOR has bit j set, which is then weighted by 2^j.
Pseudocode
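for each bit j from 0 to B-1:
    A_j = count of n in [0,N] with bit j set (closed form)
    B_j = count of k in [0,N] with suffix_xor_parity = 1 (digit DP below)
    total += 2^j * (A_j * (N+1 - B_j) + (N+1 - A_j) * B_j), reduced mod 10^9 + 7

digit DP for B_j, over bits B-1 down to 0 of k:
    State: (tight, suffix_xor_parity), starting from (true, 0) with count 1
    tight: whether k's prefix matches N's prefix exactly
    suffix_xor_parity: running XOR of bits at positions >= j
    for each digit d allowed by tight (d <= N's bit if tight, else d in {0, 1}):
        update tight, and XOR d into suffix_xor_parity if the position is >= j
    Returns count of k in [0,N] with suffix_xor_parity = 1 at end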
Complexity Analysis
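- Time: \(O(B^2)\) where \(B = \lfloor \log_2 N \rfloor + 1 = 60\) for \(N = 10^{18}\). The digit DP for each of \(B\) bit positions runs in \(O(B)\) time, giving \(O(B^2)\) total operations.
- Space: \(O(B)\) for the stored bits of \(N\); the DP itself keeps a constant number of states per bit level.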
Answer
Code
#include <bits/stdc++.h>
using namespace std;

/*
 * Problem 811: Bitwise Recursion
 *
 * f(n, k) = n                      if k = 0
 *         = f(n XOR k, k >> 1)     if k > 0
 *
 * Key insight: f(n, k) = n XOR g(k), where g(k) = k XOR (k>>1) XOR (k>>2) XOR ...
 * Bit j of g(k) = XOR of bits j, j+1, j+2, ... of k (suffix XOR parity).
 *
 * S(N) = sum_{k=0}^{N} sum_{n=0}^{N} (n XOR g(k))
 *
 * For each bit position j, count pairs (n,k) in [0,N]^2 where
 * bit j of (n XOR g(k)) is 1, then S(N) = sum_j 2^j * count_j.
 */

const long long MOD = 1e9 + 7;

// Modular exponentiation by repeated squaring: base^exp mod `mod`
long long power(long long base, long long exp, long long mod) {
    long long result = 1;
    base %= mod;
    while (exp > 0) {
        if (exp & 1) result = result * base % mod;
        base = base * base % mod;
        exp >>= 1;
    }
    return result;
}

// Count k in [0..N] where suffix XOR parity from bit j onward is 1
// Uses digit DP on binary representation of N
long long count_suffix_xor_odd(long long N, int j) {
    if (N < 0) return 0;
    int B = 63 - __builtin_clzll(N + 1); // index of the top bit of N + 1
    if (j > B) return 0;

    // Extract bits of N from MSB to LSB
    vector<int> bits(B + 1);
    for (int i = B; i >= 0; i--) {
        bits[B - i] = (N >> i) & 1;
    }

    // dp[tight][parity] = count
    // tight: are we still bounded by N?
    // parity: XOR of bits at positions >= j seen so far
    map<pair<bool,int>, long long> dp;
    dp[{true, 0}] = 1;

    for (int i = 0; i <= B; i++) {
        int actual_bit_pos = B - i;
        map<pair<bool,int>, long long> new_dp;
        for (auto& [state, cnt] : dp) {
            auto [tight, par] = state;
            int max_d = tight ? bits[i] : 1;
            for (int d = 0; d <= max_d; d++) {
                bool new_tight = tight && (d == bits[i]);
                int new_par = par;
                if (actual_bit_pos >= j) {
                    new_par = par ^ d;
                }
                new_dp[{new_tight, new_par}] += cnt;
            }
        }
        dp = new_dp;
    }

    long long result = 0;
    for (auto& [state, cnt] : dp) {
        if (state.second == 1) result += cnt;
    }
    return result;
}

// Count n in [0..N] with bit j set
long long count_bit_set(long long N, int j) {
    if (N < 0) return 0;
    long long full = (N + 1) >> (j + 1);
    long long rem = (N + 1) & ((1LL << (j + 1)) - 1);
    return full * (1LL << j) + max(0LL, rem - (1LL << j));
}

long long solve(long long N) {
    // Index of the highest set bit of N. For N = 0 use 0 so that the shifts
    // in count_bit_set never reach 64 bits (which would be undefined behaviour).
    int B = (N > 0) ? 63 - __builtin_clzll(N) : 0;
    long long total = 0;
    for (int j = 0; j <= B; j++) {
        long long ones_n = count_bit_set(N, j) % MOD;
        long long zeros_n = ((N + 1) % MOD - ones_n % MOD + MOD) % MOD;
        long long ones_k = count_suffix_xor_odd(N, j) % MOD;
        long long zeros_k = ((N + 1) % MOD - ones_k % MOD + MOD) % MOD;
        // Pairs where n_j XOR g(k)_j = 1
        long long count_j = (ones_n * zeros_k % MOD + zeros_n * ones_k % MOD) % MOD;
        total = (total + count_j % MOD * power(2, j, MOD) % MOD) % MOD;
    }
    return total;
}

// Brute force for verification
long long g_func(long long k) {
    long long result = 0;
    while (k) {
        result ^= k;
        k >>= 1;
    }
    return result;
}

long long solve_brute(int N) {
    long long total = 0;
    for (int k = 0; k <= N; k++) {
        long long gk = g_func(k);
        for (int n = 0; n <= N; n++) {
            total += (n ^ gk);
        }
    }
    return total;
}

int main() {
    // Cross-verify on small cases
    for (int N = 1; N <= 20; N++) {
        long long brute = solve_brute(N);
        long long dp = solve(N);
        assert(brute % MOD == dp);
    }
    // Compute answer for N = 10^18
    long long N = 1000000000000000000LL;
    cout << solve(N) << endl;
    return 0;
}
"""
Problem 811: Bitwise Recursion
f(n, k) = n if k = 0
= f(n XOR k, k >> 1) if k > 0
Key insight: f(n, k) = n XOR g(k) where g(k) = XOR of all right-shifts of k.
Bit j of g(k) = XOR of bits j, j+1, j+2, ... of k (suffix XOR parity).
We compute S(N) = sum_{k=0}^{N} sum_{n=0}^{N} f(n, k) mod (10^9 + 7)
= sum_{k=0}^{N} sum_{n=0}^{N} (n XOR g(k))
"""
MOD = 10**9 + 7

def g(k):
    """Compute g(k) = XOR of k, k>>1, k>>2, ... (suffix XOR parity)."""
    result = 0
    while k:
        result ^= k
        k >>= 1
    return result

def xor_sum(N, c):
    """Compute sum_{n=0}^{N} (n XOR c) using bitwise analysis."""
    if N < 0:
        return 0
    total = 0
    # For each bit position j, count how many n in [0..N] have bit j set in (n XOR c)
    for j in range(61):
        bit_c = (c >> j) & 1
        # Count of n in [0..N] with bit j set
        full_cycles = (N + 1) >> (j + 1)
        remainder = (N + 1) & ((1 << (j + 1)) - 1)
        ones_in_bit_j = full_cycles * (1 << j) + max(0, remainder - (1 << j))
        if bit_c == 1:
            # XOR flips: ones become zeros and vice versa
            count_set = (N + 1) - ones_in_bit_j
        else:
            count_set = ones_in_bit_j
        total += count_set * (1 << j)
    return total

# --- Method 1: Brute force for small N ---
def solve_brute(N):
    """Brute force: directly compute S(N)."""
    total = 0
    for k in range(N + 1):
        gk = g(k)
        for n in range(N + 1):
            total += n ^ gk
    return total

# --- Method 2: Using xor_sum with g(k) enumeration (medium N) ---
def solve_medium(N):
    """Compute S(N) by iterating over k and using xor_sum for inner sum."""
    total = 0
    for k in range(N + 1):
        gk = g(k)
        total += xor_sum(N, gk)
    return total

# --- Method 3: Digit DP for large N ---
def solve_digit_dp(N, mod):
    """
    Compute S(N) = sum_{k=0}^{N} sum_{n=0}^{N} (n XOR g(k)) mod `mod`
    using a closed-form count over n and a digit DP over k.

    For each bit position j, we need to count:
    For how many (n, k) pairs in [0,N]x[0,N] is bit j of (n XOR g(k)) equal to 1?
    Then S(N) = sum_j 2^j * count_j.

    For bit j of (n XOR g(k)): this is n_j XOR g(k)_j.
    g(k)_j = k_j XOR k_{j+1} XOR k_{j+2} XOR ...
    We compute per-bit contributions.
    """
    if N < 0:
        return 0
    B = N.bit_length()
    total = 0
    for j in range(B + 1):
        # Count pairs (n, k) in [0..N]^2 where n_j XOR suffix_xor_j(k) = 1
        # Count of n in [0..N] with bit j = 1
        full_n = (N + 1) >> (j + 1)
        rem_n = (N + 1) & ((1 << (j + 1)) - 1)
        ones_n = full_n * (1 << j) + max(0, rem_n - (1 << j))
        zeros_n = (N + 1) - ones_n
        # Count of k in [0..N] with g(k)_j = 1
        # g(k)_j = suffix XOR parity from bit j onward
        # We need to count k in [0..N] where bits j,j+1,...,B-1 of k have odd parity
        # This requires digit DP on k
        ones_k = count_suffix_xor_odd(N, j, B)
        zeros_k = (N + 1) - ones_k
        # Pairs where n_j XOR g(k)_j = 1: ones_n * zeros_k + zeros_n * ones_k
        count_j = (ones_n % mod) * (zeros_k % mod) + (zeros_n % mod) * (ones_k % mod)
        total = (total + (count_j % mod) * pow(2, j, mod)) % mod
    return total

def count_suffix_xor_odd(N, j, B):
    """
    Count k in [0..N] where XOR of bits j, j+1, ..., B-1 of k is 1.
    Uses digit DP on bits from MSB down.
    """
    if j >= B:
        return 0
    # Extract bits of N
    bits = []
    for i in range(B - 1, -1, -1):
        bits.append((N >> i) & 1)
    # DP states: (position, tight, xor_parity_of_bits_from_j_onward_so_far)
    # We process bits from MSB (position 0) to LSB (position B-1)
    # Position i corresponds to bit (B-1-i)
    # We track parity of bits at positions >= j
    # dp[tight][parity] = count
    dp = {}
    dp[(True, 0)] = 1  # start: tight, parity 0
    for i in range(B):
        actual_bit_pos = B - 1 - i
        new_dp = {}
        for (tight, par), cnt in dp.items():
            max_d = bits[i] if tight else 1
            for d in range(max_d + 1):
                new_tight = tight and (d == bits[i])
                new_par = par
                if actual_bit_pos >= j:
                    new_par = par ^ d
                key = (new_tight, new_par)
                new_dp[key] = new_dp.get(key, 0) + cnt
        dp = new_dp
    # Count entries with parity = 1
    result = 0
    for (tight, par), cnt in dp.items():
        if par == 1:
            result += cnt
    return result

# --- Verify small cases ---
N_small = 15
brute_ans = solve_brute(N_small)
medium_ans = solve_medium(N_small)
dp_ans = solve_digit_dp(N_small, 10**18 + 9)  # use large mod to get exact
assert brute_ans == medium_ans, f"Brute vs medium mismatch: {brute_ans} vs {medium_ans}"
assert brute_ans == dp_ans, f"Brute vs DP mismatch: {brute_ans} vs {dp_ans}"

# Verify that m -> m ^ (m >> 1) inverts g.
# (g is not an involution: g(g(4)) = g(7) = 5, so g(g(k)) == k fails at k = 4.)
for k in range(100):
    assert g(k) ^ (g(k) >> 1) == k, f"inverse check failed at k={k}"

# Verify g values
assert g(0) == 0
assert g(1) == 1
assert g(2) == 3
assert g(3) == 2
assert g(7) == 5
assert g(13) == 9

# Verify concrete f values
assert 5 ^ g(6) == 1  # f(5, 6) = 1
assert 5 ^ g(4) == 2  # f(5, 4) = 2

# --- Compute answer for large N ---
N_large = 10**18
answer = solve_digit_dp(N_large, MOD)
print(answer)