fork download
  1. #include <iostream>
  2. #include <vector>
  3. #include <numeric>
  4. #include <algorithm>
  5.  
  6. using namespace std;
  7. using ll = long long;
  8.  
  9. const ll MOD = 998244353;
  10.  
  11. ll exgcd(ll a, ll b, ll &x, ll &y) {
  12. if (b == 0) { x = 1; y = 0; return a; }
  13. ll x1, y1;
  14. ll d = exgcd(b, a % b, x1, y1);
  15. x = y1;
  16. y = x1 - y1 * (a / b);
  17. return d;
  18. }
  19.  
  20. ll count_valid(ll N, ll A, ll B, ll C, ll D, ll d) {
  21. ll targetA = (d - B % d) % d;
  22. ll x, y;
  23. ll gA = exgcd(A, d, x, y);
  24. if (targetA % gA != 0) return 0;
  25.  
  26. ll stepA = d / gA;
  27. x = (x % stepA + stepA) % stepA;
  28. ll mul = targetA / gA;
  29. ll X = (x * (mul % stepA)) % stepA;
  30. X = (X + stepA) % stepA;
  31.  
  32. ll targetC = (d - D % d) % d;
  33. ll At = (C * stepA) % d;
  34. ll Bt = (targetC - (C * X) % d) % d;
  35. Bt = (Bt + d) % d;
  36.  
  37. ll tx, ty;
  38. ll gt = exgcd(At, d, tx, ty);
  39. if (Bt % gt != 0) return 0;
  40.  
  41. ll stepT = d / gt;
  42. tx = (tx % stepT + stepT) % stepT;
  43. ll mulT = Bt / gt;
  44. ll T = (tx * (mulT % stepT)) % stepT;
  45. T = (T + stepT) % stepT;
  46.  
  47. ll M = stepT * stepA;
  48. ll startX = (X + T * stepA) % M;
  49. if (startX == 0) startX = M;
  50.  
  51. if (startX > N) return 0;
  52. return (N - startX) / M + 1;
  53. }
  54.  
  55. void solve() {
  56. ll N, A, B, C, D;
  57. cin >> N >> A >> B >> C >> D;
  58.  
  59. ll delta = abs(A * D - B * C);
  60. if (delta == 0) {
  61. ll g1 = __gcd(A, C);
  62. ll a = A / g1;
  63. ll k = B / a;
  64.  
  65. ll sum_i = (N % MOD) * ((N + 1) % MOD) % MOD;
  66. sum_i = sum_i * 499122177 % MOD;
  67. ll ans = (g1 % MOD) * sum_i % MOD;
  68. ans = (ans + (k % MOD) * (N % MOD)) % MOD;
  69. cout << ans << "\n";
  70. return;
  71. }
  72.  
  73. vector<ll> divs;
  74. for (ll i = 1; i * i <= delta; ++i) {
  75. if (delta % i == 0) {
  76. divs.push_back(i);
  77. if (i * i != delta) {
  78. divs.push_back(delta / i);
  79. }
  80. }
  81. }
  82. sort(divs.begin(), divs.end());
  83.  
  84. int m = divs.size();
  85. vector<ll> exact(m, 0);
  86. for (int i = 0; i < m; ++i) {
  87. exact[i] = count_valid(N, A, B, C, D, divs[i]);
  88. }
  89.  
  90. for (int i = m - 1; i >= 0; --i) {
  91. for (int j = i + 1; j < m; ++j) {
  92. if (divs[j] % divs[i] == 0) {
  93. exact[i] -= exact[j];
  94. }
  95. }
  96. }
  97.  
  98. ll ans = 0;
  99. for (int i = 0; i < m; ++i) {
  100. ans = (ans + (exact[i] % MOD) * (divs[i] % MOD)) % MOD;
  101. }
  102. cout << ans << "\n";
  103. }
  104.  
  105. int main() {
  106. ios_base::sync_with_stdio(false);
  107. cin.tie(NULL);
  108. int T;
  109. if (cin >> T) {
  110. while (T--) solve();
  111. }
  112. return 0;
  113. }
Success #stdin #stdout 0.01s 5280KB
stdin
3
4 1 3 4 2
100000000 1 1 1 2
100000000 1 1 1 1
stdout
10
100000000
822404071