Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
395 views
in Technique[技术] by (71.8m points)

c++ - AVX2 column population count algorithm over each bit-column separately

For a project I'm working on I need to count the number of set bits per column in ripped PDF image data.

I'm trying to get the total set bit count for each column in the entire PDF job (all pages).

The data, once ripped, is stored in a MemoryMappedFile with no backing file (in memory).

The PDF page dimensions are 13952 pixels x 15125 pixels. The total size of the resulting ripped data can be calculated by multiplying the length (height) of the PDF in pixels by the width in bytes. The ripped data is 1 bit == 1 pixel. So the size of a ripped page in bytes is (13952 / 8) * 15125.

Note that the width is always a multiple of 64 bits.

I'll have to count the set bits for each column in each page of a PDF (which could be tens of thousands of pages) after being ripped.

I first approached the problem with a basic solution of just looping through each byte and counting the number of set bits and placing the results in a vector. I've since whittled down the algorithm to whats shown below. I've gone from a execution time of ~350ms to ~120ms.

static void count_dots( )
{
    using namespace diag;
    using namespace std::chrono;

    std::vector<std::size_t> dot_counts( 13952, 0 );
    uint64_t* ptr_dot_counts{ dot_counts.data( ) };

    std::vector<uint64_t> ripped_pdf_data( 3297250, 0xFFFFFFFFFFFFFFFFUL );
    const uint64_t* ptr_data{ ripped_pdf_data.data( ) };

    std::size_t line_count{ 0 };
    std::size_t counter{ ripped_pdf_data.size( ) };

    stopwatch sw;
    sw.start( );

    while( counter > 0 )
    {
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 7 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 6 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 5 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 4 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 3 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 2 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 1 ) & 0x0000000000000001UL ) >> 0;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0100000000000000UL ) >> 56;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0001000000000000UL ) >> 48;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000010000000000UL ) >> 40;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000100000000UL ) >> 32;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000001000000UL ) >> 24;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000010000UL ) >> 16;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000100UL ) >> 8;
        *ptr_dot_counts++ += ( ( *ptr_data >> 0 ) & 0x0000000000000001UL ) >> 0;

        ++ptr_data;
        --counter;
        if( ++line_count >= 218 )
        {
            ptr_dot_counts = dot_counts.data( );
            line_count = 0;
        }
    }   

    sw.stop( );
    std::cout << sw.elapsed<milliseconds>( ) << "ms
";
}

Unfortunately this is still going to add a lot of extra processing time which isn't going to be acceptable.

The above code is ugly and wont win any beauty contests but it has helped in reducing execution time. Since the original version I wrote I've done the following:

  • Use pointers instead of indexers
  • Process the data in chunks of uint64 instead of uint8
  • Manually unroll the for loop for traversing each bit in each byte of a uint64
  • Use a final bit shift instead of __popcnt64 for counting the set bit after masking

For this test I'm generating phony ripped data where each bit is set to 1. The dot_counts vector should contain 15125 for each element after the test has completed.

I'm hoping some folks here can help me in getting the algorithms average execution time below 100ms. I do not care what-so-ever about portability here.

  • The target machine's CPU: Xeon E5-2680 v4 - Intel
  • Compiler: MSVC++ 14.23
  • OS: Windows 10
  • C++ Version: C++17
  • Compiler flags: /O2 /arch:AVX2

A very similar question was asked ~8 years ago: How to quickly count bits into separate bins in a series of ints on Sandy Bridge?

(Editor's note: perhaps you missed Count each bit-position separately over many 64-bit bitmasks, with AVX but not AVX2 which has some more recent faster answers, at least for going down a column instead of along a row in contiguous memory. Maybe you can go 1 or 2 cache-lines wide down a column so you can keep your counters hot in SIMD registers.)

When I compare what I have thus far to the accepted answer I'm fairly close. I was already processing in chunks of uint64 instead of uint8. I'm just wondering if there is more I can do, whether that be with intrinsics, assembly, or something simple like changing what data structures I'm using.

See Question&Answers more detail:os

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)

It could be done with AVX2, as tagged.

In order to make this work out properly, I recommend vector<uint16_t> for the counts. Adding into the counts is the biggest problem, and the more we need to widen, the bigger the problem. uint16_t is enough to count one page, so you can count one page at the time and add the counters into a set of wider counters for the totals. That is some overhead, but much less than having to widen more in the main loop.

The big-endian ordering of the counts is very annoying, introducing even more shuffles to get it right. So I recommend getting it wrong and reordering the counts later (maybe during summing them into the totals?). The order of "right shift by 7 first, then 6, then 5" can be maintained for free, because we get to choose the shift counts for the 64bit blocks any way we want. So in the code below, the actual order of counts is:

  • Bit 7 of the least significant byte,
  • Bit 7 of the second byte
  • ...
  • Bit 7 of the most significant byte,
  • Bit 6 of the least significant byte,
  • ...

So every group of 8 is reversed. (at least this is what I intended to do, AVX2 unpacks are confusing)

Code (not tested):

while( counter > 0 )
{
    __m256i data = _mm256_set1_epi64x(*ptr_data);        
    __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
    __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
    data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
    data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

    __m256i zero = _mm256_setzero_si256();

    __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
    c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

    c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
    c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
    _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);

    ptr_dot_counts += 64;
    ++ptr_data;
    --counter;
    if( ++line_count >= 218 )
    {
        ptr_dot_counts = dot_counts.data( );
        line_count = 0;
    }
}

This can be further unrolled, handling multiple rows at once. That is good because, as mentioned earlier, summing into the counters is the biggest problem, and unrolling by rows would do less of that and more plain summing in registers.

Somme intrinsics used:

  • _mm256_set1_epi64x, copies one int64_t to all 4 of the 64bit elements of the vector. Also fine for uint64_t.
  • _mm256_set_epi64x, turns 4 64bit values into a vector.
  • _mm256_srlv_epi64, shift right logical, with variable count (can be a different count for each element).
  • _mm256_and_si256, just bitwise AND.
  • _mm256_add_epi16, addition, works on 16bit elements.
  • _mm256_unpacklo_epi8 and _mm256_unpackhi_epi8, probably best explained by the diagrams on that page

It's possible to sum "vertically", using one uint64_t to hold all the 0th bits of the 64 individual sums, an other uint64_t to hold all the 1st bits of the sums etc. The addition can be done by emulating full adders (the circuit component) with bitwise arithmetic. Then instead of adding just 0 or 1 to the counters, bigger numbers are added all at once.

The vertical sums can also be vectorized, but that would significantly inflate the code that adds the vertical sums to the column sums, so I didn't do that here. It should help, but it's just a lot of code.

Example (not tested):

size_t y;
// sum 7 rows at once
for (y = 0; (y + 6) < 15125; y += 7) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        uint64_t dataA = ptr_data[0];
        uint64_t dataB = ptr_data[218];
        uint64_t dataC = ptr_data[218 * 2];
        uint64_t dataD = ptr_data[218 * 3];
        uint64_t dataE = ptr_data[218 * 4];
        uint64_t dataF = ptr_data[218 * 5];
        uint64_t dataG = ptr_data[218 * 6];
        // vertical sums, 7 bits to 3
        uint64_t abc0 = (dataA ^ dataB) ^ dataC;
        uint64_t abc1 = (dataA ^ dataB) & dataC | (dataA & dataB);
        uint64_t def0 = (dataD ^ dataE) ^ dataF;
        uint64_t def1 = (dataD ^ dataE) & dataF | (dataD & dataE);
        uint64_t bit0 = (abc0 ^ def0) ^ dataG;
        uint64_t c1   = (abc0 ^ def0) & dataG | (abc0 & def0);
        uint64_t bit1 = (abc1 ^ def1) ^ c1;
        uint64_t bit2 = (abc1 ^ def1) & c1 | (abc1 & def1);
        // add vertical sums to column counts
        __m256i bit0v = _mm256_set1_epi64x(bit0);
        __m256i data01 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data02 = _mm256_srlv_epi64(bit0v, _mm256_set_epi64x(0, 2, 1, 3));
        data01 = _mm256_and_si256(data01, _mm256_set1_epi8(1));
        data02 = _mm256_and_si256(data02, _mm256_set1_epi8(1));
        __m256i bit1v = _mm256_set1_epi64x(bit1);
        __m256i data11 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data12 = _mm256_srlv_epi64(bit1v, _mm256_set_epi64x(0, 2, 1, 3));
        data11 = _mm256_and_si256(data11, _mm256_set1_epi8(1));
        data12 = _mm256_and_si256(data12, _mm256_set1_epi8(1));
        data11 = _mm256_add_epi8(data11, data11);
        data12 = _mm256_add_epi8(data12, data12);
        __m256i bit2v = _mm256_set1_epi64x(bit2);
        __m256i data21 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data22 = _mm256_srlv_epi64(bit2v, _mm256_set_epi64x(0, 2, 1, 3));
        data21 = _mm256_and_si256(data21, _mm256_set1_epi8(1));
        data22 = _mm256_and_si256(data22, _mm256_set1_epi8(1));
        data21 = _mm256_slli_epi16(data21, 2);
        data22 = _mm256_slli_epi16(data22, 2);
        __m256i data1 = _mm256_add_epi8(_mm256_add_epi8(data01, data11), data21);
        __m256i data2 = _mm256_add_epi8(_mm256_add_epi8(data02, data12), data22);

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}
// leftover rows
for (; y < 15125; y++) {
    ptr_dot_counts = dot_counts.data( );
    ptr_data = ripped_pdf_data.data( ) + y * 218;
    for (size_t x = 0; x < 218; x++) {
        __m256i data = _mm256_set1_epi64x(*ptr_data);
        __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
        __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
        data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
        data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));

        __m256i zero = _mm256_setzero_si256();

        __m256i c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data1, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        c = _mm256_add_epi16(_mm256_unpacklo_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c);

        c = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);
        c = _mm256_add_epi16(_mm256_unpackhi_epi8(data2, zero), c);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[48], c);


        ptr_dot_counts += 64;
        ++ptr_data;
    }
}

The second best so far was a simpler approach, more like the first version except doing runs of yloopLen rows at once to take advantage of fast 8bit sums:

size_t yloopLen = 32;
size_t yblock = yloopLen * 1;
size_t yy;
for (yy = 0; yy < 15125; yy += yblock) {
    for (size_t x = 0; x < 218; x++) {
        ptr_data = ripped_pdf_data.data() + x;
        ptr_dot_counts = dot_counts.data() + x * 64;
        __m256i zero = _mm256_setzero_si256();

        __m256i c1 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[0]);
        __m256i c2 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[16]);
        __m256i c3 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[32]);
        __m256i c4 = _mm256_loadu_si256((__m256i*)&ptr_dot_counts[48]);

        size_t end = std::min(yy + yblock, size_t(15125));
        size_t y;
        for (y = yy; y < end; y += yloopLen) {
            size_t len = std::min(size_t(yloopLen), end - y);
            __m256i count1 = zero;
            __m256i count2 = zero;

            for (size_t t = 0; t < len; t++) {
                __m256i data = _mm256_set1_epi64x(ptr_data[(y + t) * 218]);
                __m256i data1 = _mm256_srlv_epi64(data, _mm256_set_epi64x(4, 6, 5, 7));
                __m256i data2 = _mm256_srlv_epi64(data, _mm256_set_epi64x(0, 2, 1, 3));
                data1 = _mm256_and_si256(data1, _mm256_set1_epi8(1));
                data2 = _mm256_and_si256(data2, _mm256_set1_epi8(1));
                count1 = _mm256_add_epi8(count1, data1);
                count2 = _mm256_add_epi8(count2, data2);
            }

            c1 = _mm256_add_epi16(_mm256_unpacklo_epi8(count1, zero), c1);
            c2 = _mm256_add_epi16(_mm256_unpackhi_epi8(count1, zero), c2);
            c3 = _mm256_add_epi16(_mm256_unpacklo_epi8(count2, zero), c3);
            c4 = _mm256_add_epi16(_mm256_unpackhi_epi8(count2, zero), c4);
        }

        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[0], c1);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[16], c2);
        _mm256_storeu_si256((__m256i*)&ptr_dot_counts[32], c3);
        _mm256_storeu_si256((__m2

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...