はじめに
この記事はアドベントカレンダー2025 16日目の記事です
コンテスト中の人へ or 結果が早く知りたい人へ
コードを示します。コードテストでも何でもにこれをコピペしたら、正しく動いていることが分かると思います。
#include <iostream>
#include <vector>
using namespace std;
unsigned long long isqrt_aux(int c,unsigned long long n){
if (c == 0){
return 1;
} else {
int k = (c - 1) / 2;
unsigned long long a = isqrt_aux(c / 2, n >> (2*k + 2));
return (a << k) + (n >> (k+2)) / a;
}
}
unsigned long isqrt(unsigned long long n){
if (n == 0){
return 0;
} else {
unsigned long long a = isqrt_aux(( 63 - __builtin_clzll(n)) / 2, n);
return n <= a * a - 1 ? a - 1 : a;
}
}
// ===========================================
int main(){
vector<unsigned long long> sample{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,100,1000,1024,314159265,4611686018427387903ULL,4611686018427387904ULL,18446744073709551614ULL,18446744073709551615ULL};
for (int i=0;i < sample.size();i++){
unsigned long long s,ss;
s = sample[i];
ss = isqrt(sample[i]);
cout << s << " -> " << ss << "\n";
}
}
導入
最近Pythonからc++に移行しようと頑張っているn3です。ADTでABC397のD問題をc++で解こうとしたときに、isqrtが必要になったので自作してみました。
isqrt(n) という、 という関数、 以下の最大の整数を返す関数があります。例えば、
みたいな関数です。この関数があると、ある数以下の平方数の数え上げや、ある数が平方数かの判定が行えます。
Pythonにはisqrtが組み込み関数で存在するのですが、c++にはisqrtに相当するものは無いので、自分で書く必要があります。二分探索で のコードを書いても良いのですが、Pythonのisqrtはニュートン法を用いて実装されていて、 で求めることができます。
ですよ!速くないですか?!今回はその内部実装とそのコードの正当性の証明を解説します。
使用例
解説の前に、実際にこのコードが正しく動いていることを確認してみましょう。
ニュートン法について
詳細はこのブログを読みましょう[1]:https://trap.jp/post/890/
要点だけをまとめると、ある関数 の解を求めるときに、適当な初期値 を用いて、
という漸化式における、 が の解に収束するというものです。
証明の概要
上のコードを見ると、2つの関数 isqrt_aux(c,n) と isqrt(n) を用いています。isqrt(n) の方は後で解説するとして、主題はisqrt_aux(c,n)です。
こいつは、
となるようなを返します。要するにisqrt_aux(c,n)は、 のいずれか一方を返します。そのため、isqrtの方で適切に確認すれば、求めたい の値を得ます。
isqrt_aux(c,n)は自身を再帰的に呼び出すことで近似値を計算します。
具体的には、与えられた数の上半分を切り出したそのisqrtの近似値に関して、それを初期値としてニュートン法を行って近似値を計算すると言った感じです。言葉だと分かりにくいので、具体的にやってみましょう。(雰囲気だけですが)[2]
を求めてみます。
- のisqrtを求めたい
- 全体の上4桁の のisqrtを求めたい
- 全体の上2桁の のisqrtを求めたい
- これは明らかに、である。
- なら、引数を100倍して である
- 初期値 、 でニュートン法を一回行うと、となる。切り捨てて とする。
- 全体の上4桁の のisqrtを求めたい
- なので、引数を10000倍して である
- 初期値 、 でニュートン法を一回行うと、となる。切り捨てて とする。
こうして得られた、に関して、
が成り立つので、求める値は である。
このようにすることで、isqrtを適切に近似することができます。上の例はあくまで適当な数値なので常に正しい答えが返ってくるとは限りませんが、Pythonのisqrtは、厳密に誤差評価をすることで必ず妥当な答えが帰ってくることを証明できます。
補題
補題1
任意の実数 ,整数 に関して、以下が成り立つ。
補題2
63 - __builtin_clzll(n) は n を二進数で表したときの桁数から1引いたもの、つまり となる。
補足:__builtin_clzll(n)は、ある数を符号なし64bit2進数で表した場合に、「1になっている最も左にあるビット」より左に0になっているビットが何個あるかを数える関数です。
補題3
補題1の3番目の式と補題2より、以下が従う
補題4
にニュートン法を用いるときの漸化式は、
となる。
定理
ここからは、適宜コードを見ながら理解する必要があるので、再掲しておきます。
unsigned long long isqrt_aux(int c,unsigned long long n){
if (c == 0){
return 1;
} else {
int k = (c - 1) / 2;
unsigned long long a = isqrt_aux(c / 2, n >> (2*k + 2));
return (a << k) + (n >> (k+2)) / a;
}
}
unsigned long isqrt(unsigned long long n){
if (n == 0){
return 0;
} else {
unsigned long long a = isqrt_aux(( 63 - __builtin_clzll(n)) / 2, n);
return n <= a * a - 1 ? a - 1 : a;
}
}
定理1
コード中のisqrt_aux(c,n) の c は、 となる。
証明
isqrt(unsigned long long n)で現れる場合(BaseCase):
isqrt(n) のコードの中にある、isqrt_aux(( 63 - __builtin_clzll(n)) / 2, n); に着目する。補題2,3より、( 63 - __builtin_clzll(n)) / 2は であるので、この場合示された。
isqrt_aux(int c,unsigned long long n)で現れる場合:
扱いやすさのため、このスコープ内で新たに定義される変数に を付ける。引数として与えられる には付けていない。
ここで、 であるならば、 であることを示せば、帰納的に示される。
ここで、補題1の式4より、
つまり、
また、
つまり、
式と式を、帰納法の仮定 に代入して、
よって、帰納法により、題意は示された。
定理2
とする。isqrt_aux(c,n)の返り値を としたとき、
となる。
コードの解説
再帰的に、 の近似値 が 式 を満たすとしたとき、 の近似値を として、ニュートン法の式
に代入し、これがまた 式 を満たす十分な近似になっているので、これを返している
証明
帰納的に示す。引数の が になるとき、つまり のとき、引数の は となり、コード中のif文より を返す。 これは明らかに式(3)を満たす。
のときに成り立つことを、以下の自然数 で式(3)が成り立つことを仮定して示す。
これは再帰的に isqrt_aux() が呼ばれ、この返り値 として、
となる。[3]
これは、分母を払ってルートを取れば、
となる。
ここで、下記の をこの を用いて、
と定義したとき、
よって、 であり、また 式 より、
ここで、 に関する を用いた評価を得る必要がある。
定理1より 、また、 なので、
よって最左辺/最右辺より、
また、式より、
よって、
以上より、 なので、
ここで、 とすると、補題1の2番目の式より、
よって、返り値としてこれを返すことで、式を満たす。
帰納法より題意は示された。
補足
これらより、先に示したisqrt(n)が正しく動作していることが分かります。
isqrt部分の抜粋
unsigned long isqrt(unsigned long long n){
if (n == 0){
return 0;
} else {
unsigned long long a = isqrt_aux(( 63 - __builtin_clzll(n)) / 2, n);
return n <= a * a - 1 ? a - 1 : a;
}
}
のときは別で処理して、の場合はisqrt_auxが を返すので、どちらが正しいかを場合分けします。
のとき、これは の満たすべき不等式
に反しているので、する必要があります。この処理は三項演算子を使ってシンプルに、n <= a * a - 1 ? a - 1 : aと書くことが出来ます。ここで、n <= a*a-1と <=で判定しているのは、a*a-1とすることで unsigned long long でオーバーフローしないようにするためです。は最大で なので、そのままn < a*aのように2乗すると0になっておかしくなります。そこでこのように計算しています。
実行速度
AtCoderのコードテストで、速度比較をします。
測定項目として、
- の連続値
- の連続値
- の乱数 個
- の乱数 個
を採用しました。どれも の計算をしています。
使用したコード (面倒くさいのでGemini製ですが)
乱数部分のコードだけ張ります。許して。 あと、二分探索は遅いので個別に計算しています。まあお気持ち測定なので、そこまで深く見ないでください (言い訳)。#include <iostream>
#include <vector>
#include <algorithm>
#include <chrono>
#include <random> // 乱数生成用
#include <iomanip>
#include <cmath>
using namespace std::chrono;
using namespace std;
#define ll long long
// ===========================================
// 再帰 + ビット演算版 (Recursive Bitwise)
// ===========================================
unsigned long long isqrt_aux(int c, unsigned long long n){
if (c == 0){
return 1;
} else {
int k = (c - 1) / 2;
unsigned long long a = isqrt_aux(c / 2, n >> (2*k + 2));
return (a << k) + (n >> (k+2)) / a;
}
}
unsigned long isqrt(unsigned long long n){
if (n == 0){
return 0;
} else {
unsigned long long a = isqrt_aux(( 63 - __builtin_clzll(n)) / 2, n);
return n < a * a ? a - 1 : a;
}
}
// ===========================================
// std::sqrt版 (nyaan_isqrt) https://nyaannyaan.github.io/library/math/isqrt.hpp.html
// ===========================================
long long nyaan_isqrt(long long n) {
if (n <= 0) return 0;
long long x = sqrt(n);
while ((x + 1) * (x + 1) <= n) x++;
while (x * x > n) x--;
return x;
}
// ===========================================
// 時間計測用
// ===========================================
inline double get_time_sec(void){
return static_cast<double>(duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count())/1000000000;
}
// ===========================================
// テスト実行関数
// ===========================================
void run_test(const string& range_name, const vector<unsigned long long>& inputs) {
int TRIALS = 5;
size_t N = inputs.size();
cout << "--------------------------------------------------" << endl;
cout << "Testing Range: " << range_name << endl;
cout << "Data Size: " << N << " random elements" << endl;
cout << "--------------------------------------------------" << endl;
// --- Test 1: isqrt ---
cout << "[ isqrt (Recursive Bitwise) ]" << endl;
double total_time_isqrt = 0;
for(int t = 0; t < TRIALS; t++){
unsigned long long sum = 0;
double start = get_time_sec();
// 事前に生成した乱数配列に対して実行
for(size_t i = 0; i < N; i++){
sum += isqrt(inputs[i]);
}
double end = get_time_sec();
double time_ms = 1000 * (end - start);
total_time_isqrt += time_ms;
cout << " Run " << t + 1 << ": CheckSum = " << sum << ", Time = " << fixed << setprecision(3) << time_ms << " ms" << endl;
}
cout << " >> Average: " << total_time_isqrt / TRIALS << " ms" << endl << endl;
// --- Test 2: nyaan_isqrt ---
cout << "[ nyaan_isqrt (std::sqrt based) ]" << endl;
double total_time_nyaan = 0;
for(int t = 0; t < TRIALS; t++){
unsigned long long sum = 0;
double start = get_time_sec();
for(size_t i = 0; i < N; i++){
sum += nyaan_isqrt(inputs[i]);
}
double end = get_time_sec();
double time_ms = 1000 * (end - start);
total_time_nyaan += time_ms;
cout << " Run " << t + 1 << ": CheckSum = " << sum << ", Time = " << fixed << setprecision(3) << time_ms << " ms" << endl;
}
cout << " >> Average: " << total_time_nyaan / TRIALS << " ms" << endl << endl;
}
int main(){
const size_t NUM_ELEMENTS = 100000000; // 10^8
vector<unsigned long long> inputs(NUM_ELEMENTS);
// 高速な乱数生成器 (メルセンヌ・ツイスタ 64bit)
mt19937_64 rng(1333);
/*
// ------------------------------------------
// パターン 1: [1, 10^9) の乱数
// ------------------------------------------
{
cout << "Generating random numbers for pattern 1..." << endl;
uniform_int_distribution<unsigned long long> dist(1, 1000000000ULL - 1);
for(size_t i = 0; i < NUM_ELEMENTS; i++) inputs[i] = dist(rng);
run_test("[ 1, 10^9 )", inputs);
}
*/
// ------------------------------------------
// パターン 2: [10^12, 2*10^12) の乱数
// ------------------------------------------
{
cout << "Generating random numbers for pattern 2..." << endl;
// 10^12 から 2*10^12 までの範囲 (大きな桁のテスト)
uniform_int_distribution<unsigned long long> dist(1000000000000ULL, 2000000000000ULL - 1);
for(size_t i = 0; i < NUM_ELEMENTS; i++) inputs[i] = dist(rng);
run_test("[ 10^12, 2*10^12 )", inputs);
}
return 0;
}
実行結果
| [ms] | 二分探索 | ニュートン法 |
|---|---|---|
| 連続: | 3305 | 861 |
| 連続: | 5369 | 126 |
| ランダム: | 4174 | 1153 |
| ランダム: | 5458 | 1477 |
考察
ランダムケースを見るに、「ニュートン法の方が~倍ほど高速である」と言えます。
また、連続ケースの の場合、とても速くなっています。自分では分からなかったのでGeminiに聞いたところ、コンパイラが再帰の結果をキャッシュしてるから速いと言われました。ランダムケースでは最適化が掛からなかったので、おそらく正しいと思われます。
ニュートン法は「最速」であるか?
ここまでダラダラと書いて来ましたが、別にニュートン法は、二分探索より速いだけで、最速ではありません。ご存知かもしれませんが、以下のようにする方が圧倒的に速いです。
Nyaan's Library: https://nyaannyaan.github.io/library/math/isqrt.hpp.html
要するに、「組み込み関数のsqrtを切り捨てて、誤差を調整する」という手法を取っています。これを含めた実行時間の表はこのようになります。
| [ms] | 二分探索 | ニュートン法 | Nyaan_isqrt |
|---|---|---|---|
| 連続: | 3305 | 861 | 194 |
| 連続: | 5369 | 126 | 179 |
| ランダム: | 4174 | 1153 | 218 |
| ランダム: | 5458 | 1477 | 216 |
全然勝ててませんね...ランダムケースだと 倍の差を付けられています。 唯一、連続 のケースで勝っています。これは再帰関数でキャッシュが残る夕いつの利点を活かした勝利でしょう。
ただし、実用的には連続区間なら商列挙のようなことをすれば で計算できるので、正直あんまり意味はないですね...
結論
CPU命令速い!!!!!!
でもニュートン法も頑張ってるので許してください。俺はニュートン法を使います()
参考
isqrtの内部実装:
二分探索isqrtの実装:二分探索isqrtとして測定に使わせていただきました(intをllに変更などしましたが)。ビット演算を巧妙に使ってて自分には理解出来ていません(すごい) ← おい
Nyaan's Library の isqrt:やっぱCPU命令にはかなわんのや...
(Youtube) One second to compute as many square roots as I can:海外ニキの動画、このブログの上位互換です
終わりに
ここまで見ていただきありがとうございました(自分のメモを兼ねて書いたブログなので、雑な部分を多いと思います)。「わざわざ証明なんか見なくても動けばいいだろ!」と思うかもしれませんが、個人的には証明を終えて楽しかったです。