// Show original code (stray web-page label; commented out so the file compiles)
#include <bits/stdc++.h>
// Inclusive 1-based loop helper: rep(i, n) runs i = 1, 2, ..., n, evaluating n once.
#define rep(i, n) for (int i = 1, i##_last_ = (n); i <= i##_last_; ++i)
using namespace std;
using pii = pair<int, int>;
using ll = long long;
// Capacity of every precomputed table.
constexpr int maxn = 10000020;
// Prime modulus for all arithmetic in this file.
constexpr int MOD = 998244353;
// Modular product x * y mod MOD. Widens to 64 bits first, since the raw product of
// two residues does not fit in an int.
int muln(int x, int y) {
    const long long product = static_cast<long long>(x) * y;
    return static_cast<int>(product % MOD);
}
// Binary exponentiation: returns x^y mod MOD, expects y >= 0.
// Scans the bits of y from low to high, squaring the base at every step.
int qpow(int x, int y) {
    int result = 1;
    for (int base = x, e = y; e != 0; e >>= 1) {
        if (e & 1)
            result = muln(result, base);
        base = muln(base, base);
    }
    return result;
}
// Modular inverse by Fermat's little theorem: x^(MOD-2) is x^{-1} mod the prime MOD,
// provided x is not a multiple of MOD.
int inv(int x) {
    const int fermat_exponent = MOD - 2;
    return qpow(x, fermat_exponent);
}
// Folds x back into [0, MOD) with a single correction.
// Assumes x lies in [-MOD, 2*MOD), which holds for a sum or difference of two residues.
int mo(int x) {
    if (x >= MOD)
        return x - MOD;
    if (x < 0)
        return x + MOD;
    return x;
}
// Factorial tables, filled in main(): fac[i] = i! mod MOD, ifac[i] = (i!)^{-1} mod MOD.
int fac[maxn], ifac[maxn];
// Binomial coefficient C(n, m) mod MOD.
// Returns 0 whenever m lies outside [0, n]. That check also rules out every negative
// table index: a negative m, or a negative n with m <= n, used to read fac/ifac out of bounds.
// Assumes n does not exceed the limit the tables were filled to in main().
int C(int n, int m) {
    if (m < 0 || n < m)
        return 0;
    return muln(fac[n], muln(ifac[m], ifac[n - m]));
}
const int INF = 0x3f3f3f3f;
// Power tables, filled in main(): p2[i] = 2^i, pi2[i] = 2^{-i}, p3[i] = 3^i (mod MOD).
int p2[maxn];
int pi2[maxn];
int p3[maxn];
// Incremental state used by S(), one slot per tp (0 or 1):
//   mm  = upper summation index of the stored answer,
//   nn  = n of the stored answer,
//   ans = the stored prefix sum.
// mm == -INF marks a slot that has never been initialized.
int mm[2] = { -INF, -INF };
int nn[2] = { -1, -1 };
int ans[2] = { -1, -1 };
// 1/2 mod MOD; G = 1 + 1/2 is the growth factor S() applies when n increases; invG = 1/G.
const int inv2 = inv(2);
const int G = inv2 + 1;
const int invG = inv(G);
// Direct O(m) evaluation of sum_{i=0}^{m} C(n, i) * 2^{-i} (mod MOD).
int brute(int n, int m) {
    int total = 0;
    int i = 0;
    while (i <= m) {
        total = mo(total + muln(C(n, i), pi2[i]));
        ++i;
    }
    return total;
}
// Returns F(n, m) = sum_{i=0}^{m} C(n, i) * 2^{-i} (mod MOD), with m clamped to n.
//
// Slot tp (0 or 1) remembers the previous (n, m) and its answer. Each call steps from
// that state instead of recomputing:
//   raise m by one:  add      C(n, m+1) / 2^{m+1}
//   lower m by one:  subtract C(n, m) / 2^m
//   raise n by one:  F(n+1, m) = G * F(n, m) - C(n, m) / 2^{m+1}
//   lower n by one:  F(n-1, m) = (F(n, m) + C(n-1, m) / 2^{m+1}) / G
// where G = 1 + 1/2. The cost is proportional to how far (n, m) moved since the last
// call on this slot. The first call on a slot falls back to brute().
int S(int n, int m, int tp) {
    int &curM = mm[tp];
    int &curN = nn[tp];
    int &cur = ans[tp];

    // Empty range: the sum is 0 for every n, so just re-anchor the slot.
    if (m < 0) {
        curM = -1;
        curN = n;
        cur = 0;
        return cur;
    }
    m = (m > n) ? n : m;

    if (curM == -INF) {
        curM = m;
        curN = n;
        cur = brute(n, m);
        return cur;
    }

    // Move the upper index first (at the old n), then move n at the new index.
    while (curM < m) {
        ++curM;
        cur = mo(cur + muln(pi2[curM], C(curN, curM)));
    }
    while (curM > m) {
        cur = mo(cur - muln(pi2[curM], C(curN, curM)) + MOD);
        --curM;
    }
    while (curN < n) {
        cur = mo(muln(cur, G) - muln(C(curN, curM), pi2[curM + 1]));
        ++curN;
    }
    while (curN > n) {
        cur = muln(invG, mo(cur + muln(C(curN - 1, curM), pi2[curM + 1])));
        --curN;
    }
    return cur;
}
// sum_{i=st}^{ed} C(c, i) * 2^{-i} (mod MOD), computed as a difference of two prefix sums.
// Each bound uses its own S() slot, so successive calls with sliding bounds stay cheap.
int A(int st, int ed, int c) {
    if (st <= ed) {
        const int upper = S(c, ed, 0);
        const int lower = S(c, st - 1, 1);
        return mo(upper - lower);
    }
    return 0;
}
// Reads one test case (n, m) and prints
//   sum_{i=0}^{n} 3^i * 2^(n-i) * C(n, i) * A(n-m-2i, n+m-2i, n-i)   (mod MOD).
// NOTE(review): assumes 0 <= n <= the precompute limit in main(); the tables are not
// bounds-checked here.
void solve() {
    int n, m;
    // On truncated or malformed input, n and m would otherwise be used uninitialized.
    if (scanf("%d%d", &n, &m) != 2)
        return;
    int res = 0;
    for (int i = 0; i <= n; ++i) {
        const int coef = muln(muln(p3[i], p2[n - i]), C(n, i));
        res = mo(res + muln(coef, A(n - m - 2 * i, n + m - 2 * i, n - i)));
    }
    printf("%d\n", res);
}
// Precomputes factorials, inverse factorials and powers of 2, 1/2 and 3 up to LIM,
// then answers T test cases.
int main() {
    // Highest index the tables are filled to. Every table has the same capacity as fac.
    const int LIM = 10000007;
    static_assert(sizeof(fac) / sizeof(fac[0]) > static_cast<size_t>(LIM),
                  "precompute limit must fit in the tables");

    fac[0] = ifac[0] = p2[0] = pi2[0] = p3[0] = 1;
    for (int i = 1; i <= LIM; ++i) {
        fac[i] = muln(fac[i - 1], i);
        pi2[i] = muln(pi2[i - 1], inv2);
        p2[i] = muln(p2[i - 1], 2);
        p3[i] = muln(p3[i - 1], 3);
    }
    // One modular inverse at the top, then walk down using 1/i! = (i+1) * 1/(i+1)!.
    ifac[LIM] = inv(fac[LIM]);
    for (int i = LIM - 1; i > 0; --i) ifac[i] = muln(ifac[i + 1], i + 1);

    int T;
    // Without this check, empty input leaves T uninitialized and `while (T--)` is UB.
    if (scanf("%d", &T) != 1)
        return 0;
    while (T--) solve();
    return 0;
}