diff --git a/cpp/nth-prime/nth_prime.cpp b/cpp/nth-prime/nth_prime.cpp index 762db61..f4089e6 100644 --- a/cpp/nth-prime/nth_prime.cpp +++ b/cpp/nth-prime/nth_prime.cpp @@ -15,31 +15,36 @@ number_t nth(size_t n, size_t sieve_size) if (n < 1) throw domain_error("invalid prime number index"); - vector primes = {2, 3, 5, 7, 11, 13}; + if (n == 1) + return 2; - if (n <= primes.size()) - return primes[n - 1]; + vector primes = {3, 5, 7, 11, 13}; - primes.reserve(n); + if (n - 2 < primes.size()) + return primes[n - 2]; + + primes.reserve(n - 1); if (auto_sieve_size == sieve_size) { sieve_size = min(size_t(32 * 1024), - size_t(n * (log(n) + log(log(n))) + 0.5)); + size_t(0.5*(n * (log(n) + log(log(n)))) + 0.5)); } vector sieve(sieve_size, false); - sieve[0] = sieve[1] = true; + sieve[0] = true; - size_t first_number = 0; + size_t first_number = 1; while (true) { for (auto p : primes) { - size_t reminder = first_number % p; - size_t begin = reminder ? p - reminder : 0; + size_t reminder = (first_number % (2*p)); + size_t begin = reminder > p ? 3*p - reminder : p - reminder; + + begin /= 2; for (size_t i = begin; i < sieve_size; i += p) { @@ -52,10 +57,10 @@ number_t nth(size_t n, size_t sieve_size) if (sieve[i]) continue; - number_t p = first_number + i; + number_t p = first_number + 2*i; primes.push_back(p); - if (primes.size() == n) + if (primes.size() == n - 1) return p; for (size_t j = i; j < sieve_size; j += p) { @@ -64,7 +69,7 @@ number_t nth(size_t n, size_t sieve_size) } } - size_t next_first_number = first_number + sieve_size; + size_t next_first_number = first_number + 2*sieve_size; if (next_first_number < first_number) throw runtime_error("failed to reach prime number");