RangeCoderBitTree.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. // Compress/RangeCoder/RangeCoderBitTree.h
  2. #ifndef __COMPRESS_RANGECODER_BIT_TREE_H
  3. #define __COMPRESS_RANGECODER_BIT_TREE_H
  4. #include "RangeCoderBit.h"
  5. #include "RangeCoderOpt.h"
  6. namespace NCompress {
  7. namespace NRangeCoder {
  8. template <int numMoveBits, int NumBitLevels>
  9. class CBitTreeEncoder
  10. {
  11. CBitEncoder<numMoveBits> Models[1 << NumBitLevels];
  12. public:
  13. void Init()
  14. {
  15. for(UInt32 i = 1; i < (1 << NumBitLevels); i++)
  16. Models[i].Init();
  17. }
  18. void Encode(CEncoder *rangeEncoder, UInt32 symbol)
  19. {
  20. UInt32 modelIndex = 1;
  21. for (int bitIndex = NumBitLevels; bitIndex != 0 ;)
  22. {
  23. bitIndex--;
  24. UInt32 bit = (symbol >> bitIndex) & 1;
  25. Models[modelIndex].Encode(rangeEncoder, bit);
  26. modelIndex = (modelIndex << 1) | bit;
  27. }
  28. };
  29. void ReverseEncode(CEncoder *rangeEncoder, UInt32 symbol)
  30. {
  31. UInt32 modelIndex = 1;
  32. for (int i = 0; i < NumBitLevels; i++)
  33. {
  34. UInt32 bit = symbol & 1;
  35. Models[modelIndex].Encode(rangeEncoder, bit);
  36. modelIndex = (modelIndex << 1) | bit;
  37. symbol >>= 1;
  38. }
  39. }
  40. UInt32 GetPrice(UInt32 symbol) const
  41. {
  42. symbol |= (1 << NumBitLevels);
  43. UInt32 price = 0;
  44. while (symbol != 1)
  45. {
  46. price += Models[symbol >> 1].GetPrice(symbol & 1);
  47. symbol >>= 1;
  48. }
  49. return price;
  50. }
  51. UInt32 ReverseGetPrice(UInt32 symbol) const
  52. {
  53. UInt32 price = 0;
  54. UInt32 modelIndex = 1;
  55. for (int i = NumBitLevels; i != 0; i--)
  56. {
  57. UInt32 bit = symbol & 1;
  58. symbol >>= 1;
  59. price += Models[modelIndex].GetPrice(bit);
  60. modelIndex = (modelIndex << 1) | bit;
  61. }
  62. return price;
  63. }
  64. };
  65. template <int numMoveBits, int NumBitLevels>
  66. class CBitTreeDecoder
  67. {
  68. CBitDecoder<numMoveBits> Models[1 << NumBitLevels];
  69. public:
  70. void Init()
  71. {
  72. for(UInt32 i = 1; i < (1 << NumBitLevels); i++)
  73. Models[i].Init();
  74. }
  75. UInt32 Decode(CDecoder *rangeDecoder)
  76. {
  77. UInt32 modelIndex = 1;
  78. RC_INIT_VAR
  79. for(int bitIndex = NumBitLevels; bitIndex != 0; bitIndex--)
  80. {
  81. // modelIndex = (modelIndex << 1) + Models[modelIndex].Decode(rangeDecoder);
  82. RC_GETBIT(numMoveBits, Models[modelIndex].Prob, modelIndex)
  83. }
  84. RC_FLUSH_VAR
  85. return modelIndex - (1 << NumBitLevels);
  86. };
  87. UInt32 ReverseDecode(CDecoder *rangeDecoder)
  88. {
  89. UInt32 modelIndex = 1;
  90. UInt32 symbol = 0;
  91. RC_INIT_VAR
  92. for(int bitIndex = 0; bitIndex < NumBitLevels; bitIndex++)
  93. {
  94. // UInt32 bit = Models[modelIndex].Decode(rangeDecoder);
  95. // modelIndex <<= 1;
  96. // modelIndex += bit;
  97. // symbol |= (bit << bitIndex);
  98. RC_GETBIT2(numMoveBits, Models[modelIndex].Prob, modelIndex, ; , symbol |= (1 << bitIndex))
  99. }
  100. RC_FLUSH_VAR
  101. return symbol;
  102. }
  103. };
  104. template <int numMoveBits>
  105. void ReverseBitTreeEncode(CBitEncoder<numMoveBits> *Models,
  106. CEncoder *rangeEncoder, int NumBitLevels, UInt32 symbol)
  107. {
  108. UInt32 modelIndex = 1;
  109. for (int i = 0; i < NumBitLevels; i++)
  110. {
  111. UInt32 bit = symbol & 1;
  112. Models[modelIndex].Encode(rangeEncoder, bit);
  113. modelIndex = (modelIndex << 1) | bit;
  114. symbol >>= 1;
  115. }
  116. }
  117. template <int numMoveBits>
  118. UInt32 ReverseBitTreeGetPrice(CBitEncoder<numMoveBits> *Models,
  119. UInt32 NumBitLevels, UInt32 symbol)
  120. {
  121. UInt32 price = 0;
  122. UInt32 modelIndex = 1;
  123. for (int i = NumBitLevels; i != 0; i--)
  124. {
  125. UInt32 bit = symbol & 1;
  126. symbol >>= 1;
  127. price += Models[modelIndex].GetPrice(bit);
  128. modelIndex = (modelIndex << 1) | bit;
  129. }
  130. return price;
  131. }
  132. template <int numMoveBits>
  133. UInt32 ReverseBitTreeDecode(CBitDecoder<numMoveBits> *Models,
  134. CDecoder *rangeDecoder, int NumBitLevels)
  135. {
  136. UInt32 modelIndex = 1;
  137. UInt32 symbol = 0;
  138. RC_INIT_VAR
  139. for(int bitIndex = 0; bitIndex < NumBitLevels; bitIndex++)
  140. {
  141. // UInt32 bit = Models[modelIndex].Decode(rangeDecoder);
  142. // modelIndex <<= 1;
  143. // modelIndex += bit;
  144. // symbol |= (bit << bitIndex);
  145. RC_GETBIT2(numMoveBits, Models[modelIndex].Prob, modelIndex, ; , symbol |= (1 << bitIndex))
  146. }
  147. RC_FLUSH_VAR
  148. return symbol;
  149. }
  150. }}
  151. #endif