ABC179 E問題: 数列の各項の剰余における周期性
詳細は上記リンクより参照.
問題の概要:
1 <= N <= 10^10 0 <= X < M <= 10^5 N, X, Mはそれぞれ整数
という制約下において,
初期値: A_1 = X 漸化式: A_(n+1) = (A_n)^2 mod M
で定義される数列A
について,
A_1
からA_N
までの各項の和を求めるというもの.
この手の問題によく触れている人ならば数列の定義を見た瞬間, 項に周期性があることは想像できる.
周期性を持つ根拠を簡単に説明すると, 以下のようになる.
数列A
の定義より, A
の各項は高々M
通り(0, ... , M-1)
.
よって鳩の巣原理より, A_1
からA_(M+1)
までの各項には同じ値を持つ2項が必ず存在する.
この2項をそれぞれA_(n1), A_(n2) (n1<n2)
とする.
ここで, 数列A
の定義より,
A_(n1) = A_(n2) であるならば, A_(n1+1) = (A_(n1))^2 mod M = (A_(n2))^2 mod M = A_(n2+1) 同様にして, A_(n1+2) = A_(n2+2) A_(n1+3) = A_(n2+3) ...
よって, n1
個目以上である数列A
の各項において
A_n = A_(n+n2-n1) (n>=n1)
である. これは数列A
の各項がn1
個目以上においてループし, その周期はn2-n1
であることを意味している.
実装方法については, 愚直に最大M+1
回 数列A
の各項を計算しループを検出する, ダブリングを用いる, などが考えられるが, 制約上M
は最大10^5
であるため計算量には余裕があると判断したことから, シンプルに実装することを選んだ.
ループの周期は自由に選んで問題ないので, 数列A
の各項をM+1
個計算し, A_(M+1)
のインデックスをn2
, A_(M+1)
と一致する値を持つ最も後ろの項のインデックスをn1
とし, 数列A
を
- ループ開始前
- ループ部
- ループ部が終了するまえに
N項目
に達する部分
の3通りに分けて計算した.
#include <algorithm> #include <bitset> #include <cassert> #include <cctype> #include <chrono> #include <climits> #include <cmath> #include <complex> #include <cstdio> #include <cstring> #include <deque> #include <fstream> #include <functional> #include <iomanip> #include <iostream> #include <iterator> #include <list> #include <map> #include <numeric> #include <queue> #include <random> #include <set> #include <sstream> #include <stack> #include <string> #include <utility> #include <valarray> #include <unordered_map> #include <unordered_set> #include <vector> // namespace using namespace std; // type using ll = long long; using ull = unsigned long long; using ld = long double; using PAIR = pair<int, int>; using PAIRLL = pair<ll, ll>; // input #define INIT ({ ios::sync_with_stdio(false); cin.tie(0); }) #define VAR(type, ...) type __VA_ARGS__; var_scan(__VA_ARGS__); template<typename T> static void var_scan(T& t) { cin >> t; } template<typename T, typename... R> static void var_scan(T& t, R&... rest) { cin >> t; var_scan(rest...); } // output #define OUT(var) cout << (var) #define SPC cout << ' ' #define TAB cout << '\t' #define BR cout << '\n' #define ENDL cout << endl #define FLUSH cout << flush // debug #ifdef DEBUG #define EPRINTF(fmt, ...) fprintf(stderr, fmt, __VA_ARGS__) #define DUMP(var) cerr << #var << '\t' << (var) << '\n' #define DUMPITERABLE(i) do { \ cerr << #i << '\t'; \ for (const auto& it : i) { cerr << it << ' '; } \ cerr << '\n'; \ } while (0) #define DUMPITERABLE2(i2) do { \ cerr << #i2 << '\t'; \ for (const auto& it : i2) { \ for (const auto& it2: it) { \ cerr << it2 << ' '; \ } \ } \ cerr << '\n'; \ } while (0) #else #define EPRINTF(fmt, ...) #define DUMP(var) #define DUMPITERABLE(i) #define DUMPITERABLE2(i) #endif // util // [l, r) #define REP(i, n) for (int i = 0; i < (n); i++) #define RREP(i, n) for (int i = (n)-1; i >= 0; i--) #define FOR(i, l, r) for (int i = (l); i < (r); i++) #define RFOR(i, r, l) for (int i = (r)-1; i >= (l); i--) #define REPLL(i, n) for (ll i = 0; i < (n); i++) #define RREPLL(i, n) for (ll i = (n)-1; i >= 0; i--) #define FORLL(i, l, r) for (ll i = (l); i < (r); i++) #define RFORLL(i, r, l) for (ll i = (r)-1; i >= (l); i--) int main(void) { INIT; // 1<= n <= 10^10 // 0 <= x < m <= 10^5 VAR(ll, n, x, m); ll an = x, n1, n2, sum = 0LL, lsum, rest; vector<ll> a; a.push_back(an); REPLL(i, m) { // a[1], ..., a[m] an = (ll)pow(an,2) % m; a.push_back(an); } n2 = m; REPLL(i, m) // a[0], ..., a[m-1] if (a[i] == a[m]) n1 = i; if (n <= m+1) { sum += accumulate(a.begin(), a.begin()+n, 0LL); } else { sum += accumulate(a.begin(), a.begin()+n1, 0LL); lsum = accumulate(a.begin()+n1, a.begin()+n2, 0LL); rest = accumulate(a.begin()+n1, a.begin()+n1+(n-n1)%(n2-n1), 0LL); sum += lsum*((n-n1)/(n2-n1)); sum += rest; } OUT(sum); ENDL; return 0; }