Bitcoin Core  27.99.0
P2P Digital Currency
sketch_impl.h
Go to the documentation of this file.
1 /**********************************************************************
2  * Copyright (c) 2018 Pieter Wuille, Greg Maxwell, Gleb Naumenko *
3  * Distributed under the MIT software license, see the accompanying *
4  * file LICENSE or http://www.opensource.org/licenses/mit-license.php.*
5  **********************************************************************/
6 
7 #ifndef _MINISKETCH_SKETCH_IMPL_H_
8 #define _MINISKETCH_SKETCH_IMPL_H_
9 
10 #include <random>
11 
12 #include "util.h"
13 #include "sketch.h"
14 #include "int_utils.h"
15 
17 template<typename F>
18 void PolyMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, const F& field) {
19  size_t modsize = mod.size();
20  CHECK_SAFE(modsize > 0 && mod.back() == 1);
21  if (val.size() < modsize) return;
22  CHECK_SAFE(val.back() != 0);
23  while (val.size() >= modsize) {
24  auto term = val.back();
25  val.pop_back();
26  if (term != 0) {
27  typename F::Multiplier mul(field, term);
28  for (size_t x = 0; x < mod.size() - 1; ++x) {
29  val[val.size() - modsize + 1 + x] ^= mul(mod[x]);
30  }
31  }
32  }
33  while (val.size() > 0 && val.back() == 0) val.pop_back();
34 }
35 
37 template<typename F>
38 void DivMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& val, std::vector<typename F::Elem>& div, const F& field) {
39  size_t modsize = mod.size();
40  CHECK_SAFE(mod.size() > 0 && mod.back() == 1);
41  if (val.size() < mod.size()) {
42  div.clear();
43  return;
44  }
45  CHECK_SAFE(val.back() != 0);
46  div.resize(val.size() - mod.size() + 1);
47  while (val.size() >= modsize) {
48  auto term = val.back();
49  div[val.size() - modsize] = term;
50  val.pop_back();
51  if (term != 0) {
52  typename F::Multiplier mul(field, term);
53  for (size_t x = 0; x < mod.size() - 1; ++x) {
54  val[val.size() - modsize + 1 + x] ^= mul(mod[x]);
55  }
56  }
57  }
58 }
59 
/** Scale polynomial a in place so its leading coefficient becomes 1.
 *
 * @param a     Polynomial, lowest degree first. Must be non-empty with a non-zero
 *              leading coefficient (verify-build check).
 * @param field Field providing Inv() and multiplication.
 * @return The factor every coefficient was multiplied by, which is the inverse of the old
 *         leading coefficient. Returns 0 if a was already monic and left unchanged.
 */
template<typename F>
typename F::Elem MakeMonic(std::vector<typename F::Elem>& a, const F& field) {
    auto& lead = a.back();
    CHECK_SAFE(lead != 0);
    if (lead == 1) return 0;
    const auto inv = field.Inv(lead);
    typename F::Multiplier scale(field, inv);
    lead = 1;
    for (auto it = a.begin(); it + 1 != a.end(); ++it) {
        *it = scale(*it);
    }
    return inv;
}
73 
/** Compute a greatest common divisor of polynomials a and b, storing it in a.
 *
 * This is Euclid's algorithm, built on MakeMonic and PolyMod, and both inputs are
 * overwritten:
 * - If the inputs turn out coprime (a constant remainder appears), a becomes {1}.
 * - Otherwise, once at least one reduction step has run, a holds a monic GCD.
 * - If b starts out empty, a is returned unchanged.
 * - b is left holding scratch data.
 */
template<typename F>
void GCD(std::vector<typename F::Elem>& a, std::vector<typename F::Elem>& b, const F& field) {
    if (b.size() > a.size()) std::swap(a, b);
    while (!b.empty()) {
        if (b.size() == 1) {
            // A constant remainder: the inputs share no non-trivial factor.
            a.clear();
            a.push_back(1);
            return;
        }
        MakeMonic(b, field);
        PolyMod(b, a, field);
        a.swap(b);
    }
}
89 
91 template<typename F>
92 void Sqr(std::vector<typename F::Elem>& poly, const F& field) {
93  if (poly.size() == 0) return;
94  poly.resize(poly.size() * 2 - 1);
95  for (size_t i = 0; i < poly.size(); ++i) {
96  auto x = poly.size() - i - 1;
97  poly[x] = (x & 1) ? 0 : field.Sqr(poly[x / 2]);
98  }
99 }
100 
/** Compute the trace map of (param*x) modulo mod, putting the result in out.
 *
 * The result is out = sum_{i=0}^{Bits-1} (param*x)^(2^i) mod `mod`. It is computed by
 * starting from param*x and applying out <- out^2 + param*x a total of Bits-1 times,
 * reducing after each step.
 *
 * @param mod   Monic modulus (as required by PolyMod).
 * @param out   Overwritten with the result.
 * @param param Multiplier applied to x.
 * @param field Field supplying Bits(), squaring and multiplication.
 */
template<typename F>
void TraceMod(const std::vector<typename F::Elem>& mod, std::vector<typename F::Elem>& out, const typename F::Elem& param, const F& field) {
    out.reserve(mod.size() * 2);
    out.assign(2, 0);
    out.back() = param;
    for (int round = 1; round < field.Bits(); ++round) {
        // Squaring leaves every odd coefficient zero, so assigning out[1] adds param*x.
        Sqr(out, field);
        if (out.size() < 2) out.resize(2);
        out[1] = param;
        PolyMod(mod, out, field);
    }
}
116 
128 template<typename F>
129 bool RecFindRoots(std::vector<std::vector<typename F::Elem>>& stack, size_t pos, std::vector<typename F::Elem>& roots, bool fully_factorizable, int depth, typename F::Elem randv, const F& field) {
130  auto& ppoly = stack[pos];
131  // We assert ppoly.size() > 1 (instead of just ppoly.size() > 0) to additionally exclude
132  // constants polynomials because
133  // - ppoly is not constant initially (this is ensured by FindRoots()), and
134  // - we never recurse on a constant polynomial.
135  CHECK_SAFE(ppoly.size() > 1 && ppoly.back() == 1);
136  /* 1st degree input: constant term is the root. */
137  if (ppoly.size() == 2) {
138  roots.push_back(ppoly[0]);
139  return true;
140  }
141  /* 2nd degree input: use direct quadratic solver. */
142  if (ppoly.size() == 3) {
143  CHECK_RETURN(ppoly[1] != 0, false); // Equations of the form (x^2 + a) have two identical solutions; contradicts square-free assumption. */
144  auto input = field.Mul(ppoly[0], field.Sqr(field.Inv(ppoly[1])));
145  auto root = field.Qrt(input);
146  if ((field.Sqr(root) ^ root) != input) {
147  CHECK_SAFE(!fully_factorizable);
148  return false; // No root found.
149  }
150  auto sol = field.Mul(root, ppoly[1]);
151  roots.push_back(sol);
152  roots.push_back(sol ^ ppoly[1]);
153  return true;
154  }
155  /* 3rd degree input and more: recurse further. */
156  if (pos + 3 > stack.size()) {
157  // Allocate memory if necessary.
158  stack.resize((pos + 3) * 2);
159  }
160  auto& poly = stack[pos];
161  auto& tmp = stack[pos + 1];
162  auto& trace = stack[pos + 2];
163  trace.clear();
164  tmp.clear();
165  for (int iter = 0;; ++iter) {
166  // Compute the polynomial (trace(x*randv) mod poly(x)) symbolically,
167  // and put the result in `trace`.
168  TraceMod(poly, trace, randv, field);
169 
170  if (iter >= 1 && !fully_factorizable) {
171  // If the polynomial cannot be factorized completely (it has an
172  // irreducible factor of degree higher than 1), we want to avoid
173  // the case where this is only detected after trying all BITS
174  // independent split attempts fail (see the assert below).
175  //
176  // Observe that if we call y = randv*x, it is true that:
177  //
178  // trace = y + y^2 + y^4 + y^8 + ... y^(FIELDSIZE/2) mod poly
179  //
180  // Due to the Frobenius endomorphism, this means:
181  //
182  // trace^2 = y^2 + y^4 + y^8 + ... + y^FIELDSIZE mod poly
183  //
184  // Or, adding them up:
185  //
186  // trace + trace^2 = y + y^FIELDSIZE mod poly.
187  // = randv*x + randv^FIELDSIZE*x^FIELDSIZE
188  // = randv*x + randv*x^FIELDSIZE
189  // = randv*(x + x^FIELDSIZE).
190  // (all mod poly)
191  //
192  // x + x^FIELDSIZE is the polynomial which has every field element
193  // as root once. Whenever x + x^FIELDSIZE is multiple of poly,
194  // this means it only has unique first degree factors. The same
195  // holds for its constant multiple randv*(x + x^FIELDSIZE) =
196  // trace + trace^2.
197  //
198  // We use this test to quickly verify whether the polynomial is
199  // fully factorizable after already having computed a trace.
200  // We don't invoke it immediately; only when splitting has failed
201  // at least once, which avoids it for most polynomials that are
202  // fully factorizable (or at least pushes the test down the
203  // recursion to factors which are smaller and thus faster).
204  tmp = trace;
205  Sqr(tmp, field);
206  for (size_t i = 0; i < trace.size(); ++i) {
207  tmp[i] ^= trace[i];
208  }
209  while (tmp.size() && tmp.back() == 0) tmp.pop_back();
210  PolyMod(poly, tmp, field);
211 
212  // Whenever the test fails, we can immediately abort the root
213  // finding. Whenever it succeeds, we can remember and pass down
214  // the information that it is in fact fully factorizable, avoiding
215  // the need to run the test again.
216  if (tmp.size() != 0) return false;
217  fully_factorizable = true;
218  }
219 
220  if (fully_factorizable) {
221  // Every successful iteration of this algorithm splits the input
222  // polynomial further into buckets, each corresponding to a subset
223  // of 2^(BITS-depth) roots. If after depth splits the degree of
224  // the polynomial is >= 2^(BITS-depth), something is wrong.
225  CHECK_RETURN(field.Bits() - depth >= std::numeric_limits<decltype(poly.size())>::digits ||
226  (poly.size() - 2) >> (field.Bits() - depth) == 0, false);
227  }
228 
229  depth++;
230  // In every iteration we multiply randv by 2. As a result, the set
231  // of randv values forms a GF(2)-linearly independent basis of splits.
232  randv = field.Mul2(randv);
233  tmp = poly;
234  GCD(trace, tmp, field);
235  if (trace.size() != poly.size() && trace.size() > 1) break;
236  }
237  MakeMonic(trace, field);
238  DivMod(trace, poly, tmp, field);
239  // At this point, the stack looks like [... (poly) tmp trace], and we want to recursively
240  // find roots of trace and tmp (= poly/trace). As we don't care about poly anymore, move
241  // trace into its position first.
242  std::swap(poly, trace);
243  // Now the stack is [... (trace) tmp ...]. First we factor tmp (at pos = pos+1), and then
244  // we factor trace (at pos = pos).
245  if (!RecFindRoots(stack, pos + 1, roots, fully_factorizable, depth, randv, field)) return false;
246  // The stack position pos contains trace, the polynomial with all of poly's roots which (after
247  // multiplication with randv) have trace 0. This is never the case for irreducible factors
248  // (which always end up in tmp), so we can set fully_factorizable to true when recursing.
249  bool ret = RecFindRoots(stack, pos, roots, true, depth, randv, field);
250  // Because of the above, recursion can never fail here.
251  CHECK_SAFE(ret);
252  return ret;
253 }
254 
263 template<typename F>
264 std::vector<typename F::Elem> FindRoots(const std::vector<typename F::Elem>& poly, typename F::Elem basis, const F& field) {
265  std::vector<typename F::Elem> roots;
266  CHECK_RETURN(poly.size() != 0, {});
267  CHECK_RETURN(basis != 0, {});
268  if (poly.size() == 1) return roots; // No roots when the polynomial is a constant.
269  roots.reserve(poly.size() - 1);
270  std::vector<std::vector<typename F::Elem>> stack = {poly};
271 
272  // Invoke the recursive factorization algorithm.
273  if (!RecFindRoots(stack, 0, roots, false, 0, basis, field)) {
274  // Not fully factorizable.
275  return {};
276  }
277  CHECK_RETURN(poly.size() - 1 == roots.size(), {});
278  return roots;
279 }
280 
281 template<typename F>
282 std::vector<typename F::Elem> BerlekampMassey(const std::vector<typename F::Elem>& syndromes, size_t max_degree, const F& field) {
283  std::vector<typename F::Multiplier> table;
284  std::vector<typename F::Elem> current, prev, tmp;
285  current.reserve(syndromes.size() / 2 + 1);
286  prev.reserve(syndromes.size() / 2 + 1);
287  tmp.reserve(syndromes.size() / 2 + 1);
288  current.resize(1);
289  current[0] = 1;
290  prev.resize(1);
291  prev[0] = 1;
292  typename F::Elem b = 1, b_inv = 1;
293  bool b_have_inv = true;
294  table.reserve(syndromes.size());
295 
296  for (size_t n = 0; n != syndromes.size(); ++n) {
297  table.emplace_back(field, syndromes[n]);
298  auto discrepancy = syndromes[n];
299  for (size_t i = 1; i < current.size(); ++i) discrepancy ^= table[n - i](current[i]);
300  if (discrepancy != 0) {
301  int x = static_cast<int>(n + 1 - (current.size() - 1) - (prev.size() - 1));
302  if (!b_have_inv) {
303  b_inv = field.Inv(b);
304  b_have_inv = true;
305  }
306  bool swap = 2 * (current.size() - 1) <= n;
307  if (swap) {
308  if (prev.size() + x - 1 > max_degree) return {}; // We'd exceed maximum degree
309  tmp = current;
310  current.resize(prev.size() + x);
311  }
312  typename F::Multiplier mul(field, field.Mul(discrepancy, b_inv));
313  for (size_t i = 0; i < prev.size(); ++i) current[i + x] ^= mul(prev[i]);
314  if (swap) {
315  std::swap(prev, tmp);
316  b = discrepancy;
317  b_have_inv = false;
318  }
319  }
320  }
321  CHECK_RETURN(current.size() && current.back() != 0, {});
322  return current;
323 }
324 
325 template<typename F>
326 std::vector<typename F::Elem> ReconstructAllSyndromes(const std::vector<typename F::Elem>& odd_syndromes, const F& field) {
327  std::vector<typename F::Elem> all_syndromes;
328  all_syndromes.resize(odd_syndromes.size() * 2);
329  for (size_t i = 0; i < odd_syndromes.size(); ++i) {
330  all_syndromes[i * 2] = odd_syndromes[i];
331  all_syndromes[i * 2 + 1] = field.Sqr(all_syndromes[i]);
332  }
333  return all_syndromes;
334 }
335 
336 template<typename F>
337 void AddToOddSyndromes(std::vector<typename F::Elem>& osyndromes, typename F::Elem data, const F& field) {
338  auto sqr = field.Sqr(data);
339  typename F::Multiplier mul(field, sqr);
340  for (auto& osyndrome : osyndromes) {
341  osyndrome ^= data;
342  data = mul(data);
343  }
344 }
345 
/** Decode a vector of odd syndromes directly into the field elements it encodes.
 *
 * This performs the same steps as SketchImpl::Decode, without the uint64 conversion:
 * 1. reconstruct all syndromes;
 * 2. run Berlekamp-Massey with capacity osyndromes.size();
 * 3. reverse the result;
 * 4. find its roots.
 * It uses the deterministic root-finding basis 1, the same value SketchImpl::SetSeed(-1)
 * selects.
 *
 * Fixes over the previous version, which could not compile once instantiated:
 * - F is now deduced, instead of being forced to F::Elem.
 * - The required max_degree is passed to BerlekampMassey.
 * - The required basis is passed to FindRoots.
 * - An empty polynomial is no longer handed to FindRoots, which rejects it via CHECK_RETURN.
 *
 * @return The decoded elements. It is also empty when decoding fails (too many elements,
 *         or no full factorization), so an empty result is ambiguous with an empty set.
 */
template<typename F>
std::vector<typename F::Elem> FullDecode(const std::vector<typename F::Elem>& osyndromes, const F& field) {
    auto asyndromes = ReconstructAllSyndromes(osyndromes, field);
    // 2*c syndromes can describe at most c elements.
    auto poly = BerlekampMassey(asyndromes, osyndromes.size(), field);
    // Empty: decoding failed. Size 1: constant polynomial, i.e. no elements.
    if (poly.size() <= 1) return {};
    std::reverse(poly.begin(), poly.end());
    const typename F::Elem basis = 1; // FindRoots rejects 0
    return FindRoots(poly, basis, field);
}
353 
354 template<typename F>
355 class SketchImpl final : public Sketch
356 {
357  const F m_field;
358  std::vector<typename F::Elem> m_syndromes;
359  typename F::Elem m_basis;
360 
361 public:
362  template<typename... Args>
363  SketchImpl(int implementation, int bits, const Args&... args) : Sketch(implementation, bits), m_field(args...) {
364  std::random_device rng;
365  std::uniform_int_distribution<uint64_t> dist;
366  m_basis = m_field.FromSeed(dist(rng));
367  }
368 
369  size_t Syndromes() const override { return m_syndromes.size(); }
370  void Init(size_t count) override { m_syndromes.assign(count, 0); }
371 
372  void Add(uint64_t val) override
373  {
374  auto elem = m_field.FromUint64(val);
376  }
377 
378  void Serialize(unsigned char* ptr) const override
379  {
380  BitWriter writer(ptr);
381  for (const auto& val : m_syndromes) {
382  m_field.Serialize(writer, val);
383  }
384  writer.Flush();
385  }
386 
387  void Deserialize(const unsigned char* ptr) override
388  {
389  BitReader reader(ptr);
390  for (auto& val : m_syndromes) {
391  val = m_field.Deserialize(reader);
392  }
393  }
394 
395  int Decode(int max_count, uint64_t* out) const override
396  {
397  auto all_syndromes = ReconstructAllSyndromes(m_syndromes, m_field);
398  auto poly = BerlekampMassey(all_syndromes, max_count, m_field);
399  if (poly.size() == 0) return -1;
400  if (poly.size() == 1) return 0;
401  if ((int)poly.size() > 1 + max_count) return -1;
402  std::reverse(poly.begin(), poly.end());
403  auto roots = FindRoots(poly, m_basis, m_field);
404  if (roots.size() == 0) return -1;
405 
406  for (const auto& root : roots) {
407  *(out++) = m_field.ToUint64(root);
408  }
409  return static_cast<int>(roots.size());
410  }
411 
412  size_t Merge(const Sketch* other_sketch) override
413  {
414  // Sad cast. This is safe only because the caller code in minisketch.cpp checks
415  // that implementation and field size match.
416  const SketchImpl* other = static_cast<const SketchImpl*>(other_sketch);
417  m_syndromes.resize(std::min(m_syndromes.size(), other->m_syndromes.size()));
418  for (size_t i = 0; i < m_syndromes.size(); ++i) {
419  m_syndromes[i] ^= other->m_syndromes[i];
420  }
421  return m_syndromes.size();
422  }
423 
424  void SetSeed(uint64_t seed) override
425  {
426  if (seed == (uint64_t)-1) {
427  m_basis = 1;
428  } else {
429  m_basis = m_field.FromSeed(seed);
430  }
431  }
432 };
433 
434 #endif
int ret
ArgsManager & args
Definition: bitcoind.cpp:270
void Flush()
Definition: int_utils.h:95
Abstract class for internal representation of a minisketch object.
Definition: sketch.h:15
const F m_field
Definition: sketch_impl.h:357
F::Elem m_basis
Definition: sketch_impl.h:359
SketchImpl(int implementation, int bits, const Args &... args)
Definition: sketch_impl.h:363
void Serialize(unsigned char *ptr) const override
Definition: sketch_impl.h:378
size_t Merge(const Sketch *other_sketch) override
Definition: sketch_impl.h:412
int Decode(int max_count, uint64_t *out) const override
Definition: sketch_impl.h:395
void Add(uint64_t val) override
Definition: sketch_impl.h:372
void Init(size_t count) override
Definition: sketch_impl.h:370
void Deserialize(const unsigned char *ptr) override
Definition: sketch_impl.h:387
size_t Syndromes() const override
Definition: sketch_impl.h:369
std::vector< typename F::Elem > m_syndromes
Definition: sketch_impl.h:358
void SetSeed(uint64_t seed) override
Definition: sketch_impl.h:424
#define CHECK_RETURN(cond, rvar)
Check a condition and return on failure in non-verify builds, crash in verify builds.
Definition: util.h:67
#define CHECK_SAFE(cond)
Check macro that does nothing in normal non-verify builds but crashes in verify builds.
Definition: util.h:50
void AddToOddSyndromes(std::vector< typename F::Elem > &osyndromes, typename F::Elem data, const F &field)
Definition: sketch_impl.h:337
void TraceMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &out, const typename F::Elem &param, const F &field)
Compute the trace map of (param*x) modulo mod, putting the result in out.
Definition: sketch_impl.h:103
bool RecFindRoots(std::vector< std::vector< typename F::Elem >> &stack, size_t pos, std::vector< typename F::Elem > &roots, bool fully_factorizable, int depth, typename F::Elem randv, const F &field)
One step of the root finding algorithm; finds roots of stack[pos] and adds them to roots.
Definition: sketch_impl.h:129
void PolyMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &val, const F &field)
Compute the remainder of a polynomial division of val by mod, putting the result in val.
Definition: sketch_impl.h:18
void GCD(std::vector< typename F::Elem > &a, std::vector< typename F::Elem > &b, const F &field)
Compute the GCD of two polynomials, putting the result in a.
Definition: sketch_impl.h:76
F::Elem MakeMonic(std::vector< typename F::Elem > &a, const F &field)
Make a polynomial monic.
Definition: sketch_impl.h:62
void Sqr(std::vector< typename F::Elem > &poly, const F &field)
Square a polynomial.
Definition: sketch_impl.h:92
std::vector< typename F::Elem > ReconstructAllSyndromes(const std::vector< typename F::Elem > &odd_syndromes, const F &field)
Definition: sketch_impl.h:326
std::vector< typename F::Elem > BerlekampMassey(const std::vector< typename F::Elem > &syndromes, size_t max_degree, const F &field)
Definition: sketch_impl.h:282
void DivMod(const std::vector< typename F::Elem > &mod, std::vector< typename F::Elem > &val, std::vector< typename F::Elem > &div, const F &field)
Compute the quotient of a polynomial division of val by mod, putting the quotient in div and the remainder in val.
Definition: sketch_impl.h:38
std::vector< typename F::Elem > FullDecode(const std::vector< typename F::Elem > &osyndromes, const F &field)
Definition: sketch_impl.h:347
std::vector< typename F::Elem > FindRoots(const std::vector< typename F::Elem > &poly, typename F::Elem basis, const F &field)
Returns the roots of a fully factorizable polynomial.
Definition: sketch_impl.h:264
static int count