cyan's blog

しょーもない事しか書いていません

ABC179 E問題: 数列の各項の剰余における周期性

atcoder.jp

詳細は上記リンクより参照.

問題の概要:

1 <= N <= 10^10
0 <= X < M <= 10^5
N, X, Mはそれぞれ整数

という制約下において,

初期値: A_1 = X
漸化式: A_(n+1) = (A_n)^2 mod M

で定義される数列Aについて,

A_1からA_Nまでの各項の和を求めるというもの.

この手の問題によく触れている人ならば数列の定義を見た瞬間, 項に周期性があることは想像できる.

周期性を持つ根拠を簡単に説明すると, 以下のようになる.

数列Aの定義より, Aの各項は高々M通り(0, ... , M-1).

よって鳩の巣原理より, A_1からA_(M+1)までの各項には同じ値を持つ2項が必ず存在する.

この2項をそれぞれA_(n1), A_(n2) (n1<n2) とする.

ここで, 数列Aの定義より,

A_(n1) = A_(n2)
であるならば, 
A_(n1+1) = (A_(n1))^2 mod M = (A_(n2))^2 mod M = A_(n2+1)
同様にして,
A_(n1+2) = A_(n2+2)
A_(n1+3) = A_(n2+3)
...

よって, n1個目以上である数列Aの各項において

A_n = A_(n+n2-n1) (n>=n1)

である. これは数列Aの各項がn1個目以上においてループし, その周期はn2-n1であることを意味している.

実装方法については, 愚直に最大M+1回 数列Aの各項を計算しループを検出する, ダブリングを用いる, などが考えられるが, 制約上Mは最大10^5であるため計算量には余裕があると判断したことから, シンプルに実装することを選んだ.

ループの周期は自由に選んで問題ないので, 数列Aの各項をM+1個計算し, A_(M+1)のインデックスをn2, A_(M+1)と一致する値を持つ最も後ろの項のインデックスをn1とし, 数列A

  • ループ開始前
  • ループ部
  • ループ部が終了するまえにN項目に達する部分

の3通りに分けて計算した.

#include <algorithm>
#include <bitset>
#include <cassert>
#include <cctype>
#include <chrono>
#include <climits>
#include <cmath>
#include <complex>
#include <cstdio>
#include <cstring>
#include <deque>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <numeric>
#include <queue>
#include <random>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <valarray>
#include <unordered_map>
#include <unordered_set>
#include <vector>
// namespace
using namespace std;
// type
using ll = long long;
using ull = unsigned long long;
using ld = long double;
using PAIR = pair<int, int>;
using PAIRLL = pair<ll, ll>;
// input
#define INIT              ({ ios::sync_with_stdio(false); cin.tie(0); })
#define VAR(type, ...)    type __VA_ARGS__; var_scan(__VA_ARGS__);
template<typename T>
static void var_scan(T& t)
{
    cin >> t;
}
template<typename T, typename... R>
static void var_scan(T& t, R&... rest)
{
    cin >> t;
    var_scan(rest...);
}
// output
#define OUT(var)          cout << (var)
#define SPC               cout << ' '
#define TAB               cout << '\t'
#define BR                cout << '\n'
#define ENDL              cout << endl
#define FLUSH             cout << flush
// debug
#ifdef DEBUG
#define EPRINTF(fmt, ...) fprintf(stderr, fmt, __VA_ARGS__)
#define DUMP(var)         cerr << #var << '\t' << (var) << '\n'
#define DUMPITERABLE(i) do {                            \
        cerr << #i << '\t';                             \
        for (const auto& it : i) { cerr << it << ' '; } \
        cerr << '\n';                                   \
} while (0)
#define DUMPITERABLE2(i2) do {                          \
        cerr << #i2 << '\t';                            \
        for (const auto& it : i2) {                     \
            for (const auto& it2: it) {                 \
                cerr << it2 << ' ';                     \
            }                                           \
        }                                               \
        cerr << '\n';                                   \
} while (0)
#else
#define EPRINTF(fmt, ...)
#define DUMP(var)
#define DUMPITERABLE(i)
#define DUMPITERABLE2(i)
#endif
// util
// [l, r)
#define REP(i, n)       for (int i = 0; i < (n); i++)
#define RREP(i, n)      for (int i = (n)-1; i >= 0; i--)
#define FOR(i, l, r)    for (int i = (l); i < (r); i++)
#define RFOR(i, r, l)   for (int i = (r)-1; i >= (l); i--)
#define REPLL(i, n)     for (ll i = 0; i < (n); i++)
#define RREPLL(i, n)    for (ll i = (n)-1; i >= 0; i--)
#define FORLL(i, l, r)  for (ll i = (l); i < (r); i++)
#define RFORLL(i, r, l) for (ll i = (r)-1; i >= (l); i--)

int main(void)
{
    INIT;

    // 1<= n <= 10^10
    // 0 <= x < m <= 10^5

    VAR(ll, n, x, m);

    ll an = x, n1, n2, sum = 0LL, lsum, rest;
    vector<ll> a;

    a.push_back(an);

    REPLL(i, m) { // a[1], ..., a[m]
        an = (ll)pow(an,2) % m;
        a.push_back(an);
    }

    n2 = m;

    REPLL(i, m) // a[0], ..., a[m-1]
        if (a[i] == a[m])
            n1 = i;

    if (n <= m+1) {
        sum += accumulate(a.begin(), a.begin()+n, 0LL);
    } else {
        sum += accumulate(a.begin(), a.begin()+n1, 0LL);
        lsum = accumulate(a.begin()+n1, a.begin()+n2, 0LL);
        rest = accumulate(a.begin()+n1, a.begin()+n1+(n-n1)%(n2-n1), 0LL);
        sum += lsum*((n-n1)/(n2-n1));
        sum += rest;
    }

    OUT(sum);

    ENDL;

    return 0;
}