#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
const int G = 3;
using VI = vector<int>;
// Binary exponentiation: returns a^b mod m (1 when b == 0).
// The exponent is consumed from its least significant bit; `base` holds a^(2^k) mod m.
// NOTE(review): a is not reduced modulo m before the first multiply, so a negative
// base gives a negative residue; the callers in this file pass values in [0, MOD).
int fastpow(int a, int b, int m = MOD) {
    long long acc = 1;
    for (long long base = a; b != 0; b >>= 1) {
        if (b & 1) {
            acc = acc * base % m;
        }
        base = base * base % m;
    }
    return static_cast<int>(acc);
}
// In-place number-theoretic transform over Z/MOD with primitive root G.
//   invert == false: forward transform.
//   invert == true:  inverse transform, including the final multiplication by n^{-1}.
// Preconditions: a.size() is a power of two (or 0/1) dividing MOD - 1, i.e. at most
// 2^23 for MOD = 998244353, and every element is already in [0, MOD).
// Each level's twiddle factors are built once into `tw` and shared by every block of
// that level. This replaces the running `w = w * wlen % MOD`, which cost a second
// modular multiplication per butterfly. The butterfly uses a single conditional
// correction instead of two `%` operations.
void ntt(vector<int>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Bit-reversal permutation so the butterflies can run bottom-up in place.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    vector<int> tw(n / 2);  // tw[j] = wlen^j for the current level
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        int wlen = fastpow(G, (MOD - 1) / len);    // primitive len-th root of unity
        if (invert) wlen = fastpow(wlen, MOD - 2);  // its inverse for the inverse transform
        tw[0] = 1;
        for (int j = 1; j < half; ++j)
            tw[j] = static_cast<int>(1LL * tw[j - 1] * wlen % MOD);
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < half; ++j) {
                const int u = a[i + j];
                const int v = static_cast<int>(1LL * a[i + j + half] * tw[j] % MOD);
                // u, v are in [0, MOD): u + v < 2 * MOD < 2^31 cannot overflow int,
                // and one correction brings each result back into [0, MOD).
                const int s = u + v;
                a[i + j] = s >= MOD ? s - MOD : s;
                const int d = u - v;
                a[i + j + half] = d < 0 ? d + MOD : d;
            }
        }
    }
    if (invert) {
        const int inv_n = fastpow(n, MOD - 2);
        for (int& x : a) x = static_cast<int>(1LL * x * inv_n % MOD);
    }
}
// Converts a decimal string into a little-endian digit vector (least significant first).
// Each character maps to c - '0' without validation; leading zeros are kept.
VI parse(const string& s) {
    VI digits(s.rbegin(), s.rend());
    for (int& d : digits) d -= '0';
    return digits;
}
// Adds two non-negative numbers given as little-endian digit vectors.
// Returns the sum as a BIG-endian digit vector (most significant digit first, as
// print() expects) without leading zeros; a zero result is {0}.
// Operands are taken by const reference. The previous by-value signature copied both
// inputs and then grew the copy of `a` only so it could be indexed.
VI add(const VI& a, const VI& b) {
    const size_t n = max(a.size(), b.size());
    VI res;
    res.reserve(n + 1);
    int carry = 0;
    for (size_t i = 0; i < n || carry; ++i) {
        int val = carry;
        if (i < a.size()) val += a[i];
        if (i < b.size()) val += b[i];
        res.push_back(val % 10);
        carry = val / 10;
    }
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    if (res.empty()) res.push_back(0);  // both inputs empty: still return a valid zero
    reverse(res.begin(), res.end());
    return res;
}
// Computes a - b for non-negative little-endian digit vectors.
// Precondition (unchecked): value(a) >= value(b). Digits of b beyond a.size() are
// ignored, which is correct only because the precondition forces them to be zero.
// Returns the difference as a BIG-endian digit vector without leading zeros ({0} for
// zero). Operands are taken by const reference to avoid copying both inputs.
VI sub(const VI& a, const VI& b) {
    VI res;
    res.reserve(a.size());
    int borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        int val = a[i] - borrow - (i < b.size() ? b[i] : 0);
        borrow = val < 0 ? 1 : 0;
        if (borrow) val += 10;
        res.push_back(val);
    }
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    if (res.empty()) res.push_back(0);  // empty `a`: still return a valid zero
    reverse(res.begin(), res.end());
    return res;
}
// Multiplies two non-negative little-endian digit vectors via NTT convolution.
// Returns the product as a BIG-endian digit vector without leading zeros.
// Exactness limit: each convolution coefficient is at most 81 * min(a.size(), b.size())
// and must stay below MOD, i.e. roughly 12 million digits in the shorter operand. The
// transform length must also not exceed 2^23 (see ntt).
// The transform length is the smallest power of two >= a.size() + b.size() - 1, the
// number of product coefficients. The old bound (a.size() + b.size()) doubled the
// transform whenever that sum was exactly one more than a power of two.
VI multiply(VI a, VI b) {
    if (a.empty() || b.empty()) return VI{0};  // also guards the `- 1` below
    const size_t need = a.size() + b.size() - 1;
    size_t n = 1;
    while (n < need) n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (size_t i = 0; i < n; ++i)
        a[i] = static_cast<int>(1LL * a[i] * b[i] % MOD);
    ntt(a, true);
    VI res;
    res.reserve(n + 12);
    long long carry = 0;
    for (size_t i = 0; i < n; ++i) {
        const long long cur = a[i] + carry;
        res.push_back(static_cast<int>(cur % 10));
        carry = cur / 10;
    }
    while (carry) {
        res.push_back(static_cast<int>(carry % 10));
        carry /= 10;
    }
    while (res.size() > 1 && res.back() == 0) res.pop_back();
    reverse(res.begin(), res.end());
    return res;
}
void print(VI v) {
for (int d : v) cout << d;
cout << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
string A, B, op;
cin >> A >> op >> B;
auto a = parse(A);
auto b = parse(B);
if (op == "+") {
print(add(a, b));
} else if (op == "-") {
if (A == B) {
cout << 0 << '\n';
} else {
// a > b assumed
print(sub(a, b));
}
} else if (op == "*") {
print(multiply(a, b));
}
return 0;
}
#include <bits/stdc++.h>
using namespace std;
const int MOD = 998244353;
const int G = 3;
using VI = vector<int>;
// Returns a^b mod MOD by binary exponentiation (1 when b == 0).
// The exponent is scanned from its lowest bit; `sq` holds a^(2^k) reduced mod MOD.
int fastpow(int a, int b) {
    int64_t result = 1;
    int64_t sq = a;
    for (; b != 0; b >>= 1) {
        if ((b & 1) != 0) result = result * sq % MOD;
        sq = sq * sq % MOD;
    }
    return static_cast<int>(result);
}
// In-place number-theoretic transform over Z/MOD with primitive root G.
//   invert == false: forward transform.
//   invert == true:  inverse transform, including the final multiplication by n^{-1}.
// Preconditions: a.size() is a power of two (or 0/1) dividing MOD - 1 (at most 2^23
// for MOD = 998244353), and every element is already in [0, MOD).
// Each level's twiddle factors are computed once into `tw` and reused by every block
// of that level. This removes the per-butterfly `w = w * wlen % MOD`, which was a
// second modular multiplication inside the hot loop.
void ntt(VI &a, bool invert) {
    const int n = static_cast<int>(a.size());
    // Bit-reversal permutation.
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1)
            j ^= bit;
        j ^= bit;
        if (i < j) swap(a[i], a[j]);
    }
    VI tw(n / 2);  // tw[j] = wlen^j for the current level
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        int wlen = fastpow(G, (MOD - 1) / len);    // primitive len-th root of unity
        if (invert) wlen = fastpow(wlen, MOD - 2);  // inverse root for the inverse transform
        tw[0] = 1;
        for (int j = 1; j < half; ++j)
            tw[j] = static_cast<int>(static_cast<int64_t>(tw[j - 1]) * wlen % MOD);
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < half; ++j) {
                const int u = a[i + j];
                const int v = static_cast<int>(static_cast<int64_t>(a[i + j + half]) * tw[j] % MOD);
                // u + v < 2 * MOD < 2^31, so neither expression overflows int.
                a[i + j] = u + v < MOD ? u + v : u + v - MOD;
                a[i + j + half] = u - v >= 0 ? u - v : u - v + MOD;
            }
        }
    }
    if (invert) {
        const int inv_n = fastpow(n, MOD - 2);
        for (int &x : a)
            x = static_cast<int>(static_cast<int64_t>(x) * inv_n % MOD);
    }
}
// Converts a decimal string into a little-endian digit vector: element 0 holds the
// last character. The digits are stored reversed here so later arithmetic can walk
// from the least significant digit without reversing again.
// Characters are not validated; each maps to c - '0'.
VI parse(const string &s) {
    VI v;
    v.reserve(s.size());
    for (auto it = s.rbegin(); it != s.rend(); ++it)
        v.push_back(*it - '0');
    return v;
}
// Adds two little-endian digit vectors and returns their little-endian sum with
// surplus high-order zeros removed. A lone 0 is kept; two empty inputs give an
// empty result.
VI add(const VI &a, const VI &b) {
    const size_t len = max(a.size(), b.size());
    VI sum;
    sum.reserve(len + 1);
    int carry = 0;
    size_t pos = 0;
    while (pos < len || carry != 0) {
        const int da = pos < a.size() ? a[pos] : 0;
        const int db = pos < b.size() ? b[pos] : 0;
        const int total = da + db + carry;
        sum.push_back(total % 10);
        carry = total / 10;
        ++pos;
    }
    // Remove redundant trailing (high-order) zeros.
    while (sum.size() > 1 && sum.back() == 0)
        sum.pop_back();
    return sum;
}
// Computes a - b for little-endian digit vectors and returns the little-endian
// difference with high-order zeros trimmed (a lone 0 is kept).
// Precondition (unchecked): value(a) >= value(b); digits of b past a.size() are ignored.
VI sub(const VI &a, const VI &b) {
    VI diff(a.size());
    int borrow = 0;
    for (size_t k = 0; k < a.size(); ++k) {
        const int subtrahend = (k < b.size() ? b[k] : 0) + borrow;
        int digit = a[k] - subtrahend;
        borrow = 0;
        if (digit < 0) {
            digit += 10;
            borrow = 1;
        }
        diff[k] = digit;
    }
    while (diff.size() > 1 && diff.back() == 0)
        diff.pop_back();
    return diff;
}
// Multiplies two non-negative little-endian digit vectors via NTT convolution and
// returns the little-endian product without high-order zeros.
// Exactness limit: each convolution coefficient is at most 81 * min(a.size(), b.size())
// and must stay below MOD (about 12 million digits in the shorter operand). The
// transform length must not exceed 2^23 (see ntt).
// Changes from the previous version:
// - The transform length is the smallest power of two >= a.size() + b.size() - 1 (the
//   number of product coefficients). The old bound doubled the transform whenever
//   a.size() + b.size() was one more than a power of two.
// - `a` is reused as the carry/output buffer instead of allocating a second n-sized vector.
VI multiply(VI a, VI b) {
    if (a.empty() || b.empty()) return VI{0};  // also guards the `- 1` below
    const size_t need = a.size() + b.size() - 1;
    size_t n = 1;
    while (n < need)
        n <<= 1;
    a.resize(n);
    b.resize(n);
    ntt(a, false);
    ntt(b, false);
    for (size_t i = 0; i < n; i++)
        a[i] = static_cast<int>(static_cast<int64_t>(a[i]) * b[i] % MOD);
    ntt(a, true);
    // Propagate decimal carries in place; a[i] is read before it is overwritten.
    int64_t carry = 0;
    for (size_t i = 0; i < n; i++) {
        const int64_t cur = a[i] + carry;
        a[i] = static_cast<int>(cur % 10);
        carry = cur / 10;
    }
    while (carry > 0) {
        a.push_back(static_cast<int>(carry % 10));
        carry /= 10;
    }
    while (a.size() > 1 && a.back() == 0)
        a.pop_back();
    return a;
}
void print(const VI &v) {
for (int i = (int)v.size() - 1; i >= 0; i--)
cout << v[i];
cout << '\n';
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
string A, B, op;
cin >> A >> op >> B;
VI a = parse(A), b = parse(B);
if (op == "+") {
print(add(a, b));
} else if (op == "-") {
if (A == B) {
cout << 0 << '\n';
} else {
print(sub(a, b)); // 假設 A >= B
}
} else if (op == "*") {
print(multiply(a, b));
}
return 0;
}
AC (0.7s, 43.4MB) |
CPP
|
#include <algorithm>  // std::reverse, std::swap
#include <cmath>      // std::acos, std::cos, std::sin, std::llround
#include <complex>    // std::complex<long double>
#include <cstddef>    // std::size_t
#include <cstdio>     // std::scanf, std::printf, std::fprintf
#include <cstring>    // std::strlen
#include <string>     // std::string, std::to_string
#include <utility>    // std::move
#include <vector>     // std::vector

// Strips leading '0' characters; an empty or all-zero string becomes "0".
std::string removeLeadingZeros(std::string s) {
    s.erase(0, s.find_first_not_of('0'));
    return s.empty() ? "0" : s;
}

// Adds two non-negative decimal strings; the result has no leading zeros.
std::string Add(std::string a, std::string b) {
    if (a.length() < b.length()) std::swap(a, b);  // make `a` the longer operand
    std::string res;
    res.reserve(a.length() + 1);
    int carry = 0;
    for (std::size_t i = 0; i < a.length(); ++i) {
        int sum = (a[a.length() - 1 - i] - '0') + carry;
        if (i < b.length()) sum += b[b.length() - 1 - i] - '0';
        res.push_back(static_cast<char>('0' + sum % 10));
        carry = sum / 10;
    }
    if (carry) res.push_back(static_cast<char>('0' + carry));
    std::reverse(res.begin(), res.end());
    return removeLeadingZeros(res);
}

// Returns a - b for non-negative decimal strings, with a leading '-' when a < b.
// Both operands are normalised first. The length-then-lexicographic comparison below
// is a numeric comparison only when neither string has leading zeros; without this
// step "007" - "10" produced "997" and "0" - "00" produced "-0".
std::string Sub(std::string a, std::string b) {
    a = removeLeadingZeros(std::move(a));
    b = removeLeadingZeros(std::move(b));
    if (a == b) return "0";
    bool neg = false;
    if (a.length() < b.length() || (a.length() == b.length() && a < b)) {
        neg = true;
        std::swap(a, b);  // always subtract the smaller magnitude from the larger
    }
    std::string res;
    res.reserve(a.length());
    int borrow = 0;
    for (std::size_t i = 0; i < a.length(); ++i) {
        int diff = (a[a.length() - 1 - i] - '0') - borrow;
        if (i < b.length()) diff -= b[b.length() - 1 - i] - '0';
        borrow = diff < 0 ? 1 : 0;
        if (borrow) diff += 10;
        res.push_back(static_cast<char>('0' + diff));
    }
    std::reverse(res.begin(), res.end());
    return (neg ? "-" : "") + removeLeadingZeros(res);
}

const long double PI = std::acos(-1.0L);

// In-place iterative radix-2 FFT on a power-of-two-sized vector.
// invert == true computes the inverse transform, including the division by n.
// Twiddle factors come from a table where each root is evaluated directly with
// cos/sin. This replaces the running product `w *= wlen`, whose rounding error grows
// with the butterfly length and eats into the precision needed to round each product
// coefficient to the right integer.
void fft(std::vector<std::complex<long double>>& a, bool invert) {
    const int n = static_cast<int>(a.size());
    for (int i = 1, j = 0; i < n; ++i) {
        int bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }
    // root[k] = exp(+-2*pi*i*k/n); the level with block length `len` uses every
    // (n / len)-th entry, which equals exp(+-2*pi*i*j/len).
    std::vector<std::complex<long double>> root(n / 2);
    const long double sign = invert ? -1.0L : 1.0L;
    for (int k = 0; k < n / 2; ++k) {
        const long double ang = sign * 2 * PI * k / n;
        root[k] = std::complex<long double>(std::cos(ang), std::sin(ang));
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int half = len >> 1;
        const int step = n / len;
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < half; ++j) {
                const std::complex<long double> u = a[i + j];
                const std::complex<long double> v = a[i + j + half] * root[j * step];
                a[i + j] = u + v;
                a[i + j + half] = u - v;
            }
        }
    }
    if (invert) {
        for (auto& x : a) x /= static_cast<long double>(n);
    }
}

// Splits a decimal string of length `len` into base-10^power chunks, least
// significant chunk first, as FFT input coefficients.
static std::vector<std::complex<long double>> toChunks(const char* s, int len, int power) {
    std::vector<std::complex<long double>> out;
    out.reserve(len / power + 1);
    for (int i = len - 1; i >= 0;) {
        long long chunk = 0;
        long long p10 = 1;
        for (int j = 0; j < power && i >= 0; ++j) {
            chunk += (s[i--] - '0') * p10;
            p10 *= 10;
        }
        out.emplace_back(static_cast<long double>(chunk), 0.0L);
    }
    return out;
}

// Multiplies two non-negative decimal strings with a long double FFT in base 10^4.
// Returns the product as a decimal string without leading zeros.
std::string multiply_fft(const char* a_str, const char* b_str) {
    const int BASE = 10000;  // each coefficient holds POWER decimal digits
    const int POWER = 4;
    const int na_len = static_cast<int>(std::strlen(a_str));
    const int nb_len = static_cast<int>(std::strlen(b_str));
    if (na_len == 0 || nb_len == 0) return "0";  // guards the `- 1` below
    std::vector<std::complex<long double>> fa = toChunks(a_str, na_len, POWER);
    std::vector<std::complex<long double>> fb = toChunks(b_str, nb_len, POWER);

    // The product polynomial has fa.size() + fb.size() - 1 coefficients.
    const std::size_t need = fa.size() + fb.size() - 1;
    std::size_t fn = 1;
    while (fn < need) fn <<= 1;
    fa.resize(fn);
    fb.resize(fn);

    fft(fa, false);
    fft(fb, false);
    for (std::size_t i = 0; i < fn; ++i) fa[i] *= fb[i];
    fft(fa, true);

    // Round the coefficients back to integers and propagate carries in base 10^4.
    std::vector<long long> res_digits;
    res_digits.reserve(fn + 1);
    long long carry = 0;
    for (std::size_t i = 0; i < fn || carry; ++i) {
        const long long current = (i < fn ? std::llround(fa[i].real()) : 0) + carry;
        res_digits.push_back(current % BASE);
        carry = current / BASE;
    }
    while (res_digits.size() > 1 && res_digits.back() == 0) res_digits.pop_back();

    // Most significant chunk unpadded, then every lower chunk as exactly POWER digits.
    std::string result = std::to_string(res_digits.back());
    result.reserve(result.size() + POWER * (res_digits.size() - 1));
    for (std::size_t k = res_digits.size() - 1; k-- > 0;) {
        long long chunk = res_digits[k];
        char buf[POWER];
        for (int d = POWER - 1; d >= 0; --d) {
            buf[d] = static_cast<char>('0' + chunk % 10);
            chunk /= 10;
        }
        result.append(buf, POWER);
    }
    return removeLeadingZeros(result);
}

char a_str_input[1000002];
char b_str_input[1000002];

// Reads "A op B", where op is '+', '-' or '*', and prints the result.
// Only C stdio is used, so the iostream sync_with_stdio/tie calls were dropped; they
// had no effect on scanf/printf.
int main() {
    char op = 0;
    // The field widths keep scanf inside the buffers (1000001 characters + NUL); a bare
    // "%s" could overflow them on oversized input. Checking the return value avoids
    // switching on an uninitialised operator when the input is malformed.
    if (std::scanf("%1000001s %c %1000001s", a_str_input, &op, b_str_input) != 3) {
        std::fprintf(stderr, "Error: expected input of the form A op B\n");
        return 1;
    }
    std::string answer;
    switch (op) {
        case '+':
            answer = Add(a_str_input, b_str_input);
            break;
        case '-':
            answer = Sub(a_str_input, b_str_input);
            break;
        case '*':
            answer = multiply_fft(a_str_input, b_str_input);
            break;
        default:
            std::fprintf(stderr, "Error: Unknown operator '%c'\n", op);
            return 1;
    }
    std::printf("%s\n", answer.c_str());
    return 0;
}
大佬我的程式哪邊還能優化?