ABC179 E問題: 数列の各項の剰余における周期性
詳細は上記リンクより参照.
問題の概要:
1 <= N <= 10^10 0 <= X < M <= 10^5 N, X, Mはそれぞれ整数
という制約下において,
初期値: A_1 = X 漸化式: A_(n+1) = (A_n)^2 mod M
で定義される数列Aについて,
A_1からA_Nまでの各項の和を求めるというもの.
この手の問題によく触れている人ならば数列の定義を見た瞬間, 項に周期性があることは想像できる.
周期性を持つ根拠を簡単に説明すると, 以下のようになる.
数列Aの定義より, Aの各項は高々M通り(0, ... , M-1).
よって鳩の巣原理より, A_1からA_(M+1)までの各項には同じ値を持つ2項が必ず存在する.
この2項をそれぞれA_(n1), A_(n2) (n1<n2) とする.
ここで, 数列Aの定義より,
A_(n1) = A_(n2) であるならば, A_(n1+1) = (A_(n1))^2 mod M = (A_(n2))^2 mod M = A_(n2+1) 同様にして, A_(n1+2) = A_(n2+2) A_(n1+3) = A_(n2+3) ...
よって, n1個目以上である数列Aの各項において
A_n = A_(n+n2-n1) (n>=n1)
である. これは数列Aの各項がn1個目以上においてループし, その周期はn2-n1であることを意味している.
実装方法については, 愚直に最大M+1回 数列Aの各項を計算しループを検出する, ダブリングを用いる, などが考えられるが, 制約上Mは最大10^5であるため計算量には余裕があると判断したことから, シンプルに実装することを選んだ.
ループの周期は自由に選んで問題ないので, 数列Aの各項をM+1個計算し, A_(M+1)のインデックスをn2, A_(M+1)と一致する値を持つ最も後ろの項のインデックスをn1とし, 数列Aを
- ループ開始前
- ループ部
- ループ部が終了するまえに
N項目に達する部分
の3通りに分けて計算した.
#include <algorithm>
#include <bitset>
#include <cassert>
#include <cctype>
#include <chrono>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <valarray>
#include <unordered_map>
#include <unordered_set>
#include <vector>
// namespace
using namespace std;
// type
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using PAIR = pair<int, int>;
using PAIRLL = pair<ll, ll>;
// input
#define INIT ({ ios::sync_with_stdio(false); cin.tie(0); })
#define VAR(type, ...) type __VA_ARGS__; var_scan(__VA_ARGS__);
template<typename T>
static void var_scan(T& t)
{
cin >> t;
}
template<typename T, typename... R>
static void var_scan(T& t, R&... rest)
{
cin >> t;
var_scan(rest...);
}
// output
#define OUT(var) cout << (var)
#define SPC cout << ' '
#define TAB cout << '\t'
#define BR cout << '\n'
#define ENDL cout << endl
#define FLUSH cout << flush
// debug
#ifdef DEBUG
#define EPRINTF(fmt, ...) fprintf(stderr, fmt, __VA_ARGS__)
#define DUMP(var) cerr << #var << '\t' << (var) << '\n'
#define DUMPITERABLE(i) do { \
cerr << #i << '\t'; \
for (const auto& it : i) { cerr << it << ' '; } \
cerr << '\n'; \
} while (0)
#define DUMPITERABLE2(i2) do { \
cerr << #i2 << '\t'; \
for (const auto& it : i2) { \
for (const auto& it2: it) { \
cerr << it2 << ' '; \
} \
} \
cerr << '\n'; \
} while (0)
#else
#define EPRINTF(fmt, ...)
#define DUMP(var)
#define DUMPITERABLE(i)
#define DUMPITERABLE2(i)
#endif
// util
// [l, r)
#define REP(i, n) for (int i = 0; i < (n); i++)
#define RREP(i, n) for (int i = (n)-1; i >= 0; i--)
#define FOR(i, l, r) for (int i = (l); i < (r); i++)
#define RFOR(i, r, l) for (int i = (r)-1; i >= (l); i--)
#define REPLL(i, n) for (ll i = 0; i < (n); i++)
#define RREPLL(i, n) for (ll i = (n)-1; i >= 0; i--)
#define FORLL(i, l, r) for (ll i = (l); i < (r); i++)
#define RFORLL(i, r, l) for (ll i = (r)-1; i >= (l); i--)
int main(void)
{
INIT;
// 1<= n <= 10^10
// 0 <= x < m <= 10^5
VAR(ll, n, x, m);
ll an = x, n1, n2, sum = 0LL, lsum, rest;
vector<ll> a;
a.push_back(an);
REPLL(i, m) { // a[1], ..., a[m]
an = (ll)pow(an,2) % m;
a.push_back(an);
}
n2 = m;
REPLL(i, m) // a[0], ..., a[m-1]
if (a[i] == a[m])
n1 = i;
if (n <= m+1) {
sum += accumulate(a.begin(), a.begin()+n, 0LL);
} else {
sum += accumulate(a.begin(), a.begin()+n1, 0LL);
lsum = accumulate(a.begin()+n1, a.begin()+n2, 0LL);
rest = accumulate(a.begin()+n1, a.begin()+n1+(n-n1)%(n2-n1), 0LL);
sum += lsum*((n-n1)/(n2-n1));
sum += rest;
}
OUT(sum);
ENDL;
return 0;
}