1#!/usr/bin/env python
2
3'''
4/**************************************************************************
5 *
6 * Copyright 2009-2010 VMware, Inc.
7 * All Rights Reserved.
8 *
9 * Permission is hereby granted, free of charge, to any person obtaining a
10 * copy of this software and associated documentation files (the
11 * "Software"), to deal in the Software without restriction, including
12 * without limitation the rights to use, copy, modify, merge, publish,
13 * distribute, sub license, and/or sell copies of the Software, and to
14 * permit persons to whom the Software is furnished to do so, subject to
15 * the following conditions:
16 *
17 * The above copyright notice and this permission notice (including the
18 * next paragraph) shall be included in all copies or substantial portions
19 * of the Software.
20 *
21 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
22 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
23 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
24 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
25 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
26 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
27 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
28 *
29 **************************************************************************/
30
31/**
32 * @file
33 * Pixel format packing and unpacking functions.
34 *
35 * @author Jose Fonseca <jfonseca@vmware.com>
36 */
37'''
38from __future__ import print_function
39
40
41from u_format_parse import *
42
43
def inv_swizzles(swizzles):
    '''Return an array[4] of inverse swizzle terms.

    For each source channel c (0..3), ``result[c]`` is the index of the first
    output component whose swizzle selects channel c, or None when no
    component reads that channel.  Swizzle values >= 4 do not reference a
    source channel and are ignored.

    Only the first matching component is kept, to avoid l8 getting blue and
    i8 getting alpha.
    '''
    inv_swizzle = [None] * 4
    for i, swizzle in enumerate(swizzles[:4]):
        if swizzle < 4 and inv_swizzle[swizzle] is None:
            inv_swizzle[swizzle] = i
    return inv_swizzle
53
def print_channels(format, func):
    '''Call func(channels, swizzles) for each channel layout of the format.

    Formats with at most one channel only use the little-endian layout.
    Otherwise the big-endian and little-endian layouts are both emitted,
    separated by a C preprocessor conditional on PIPE_ARCH_BIG_ENDIAN.
    '''
    le_layout = (format.le_channels, format.le_swizzles)
    if format.nr_channels() <= 1:
        func(*le_layout)
        return

    be_layout = (format.be_channels, format.be_swizzles)
    for directive, layout in (('#ifdef PIPE_ARCH_BIG_ENDIAN', be_layout),
                              ('#else', le_layout)):
        print(directive)
        func(*layout)
    print('#endif')
63
def generate_format_type(format):
    '''Print the C union type describing the layout of a PLAIN format.

    The union is named ``util_format_<short name>``.  It has an integer
    ``value`` member spanning the whole block when the block size is 8, 16,
    32 or 64 bits, plus a ``chan`` struct with one member per channel.
    Channel members are declared as C bitfields whenever some little-endian
    channel size is not both a multiple of 8 and a power of two (per
    is_pot); otherwise they use full-width fixed-size C types.
    '''
    assert format.layout == PLAIN

    def emit_bitfields(channels, swizzles):
        for chan in channels:
            kind, name, size = chan.type, chan.name, chan.size
            if kind == FLOAT and size in (32, 64):
                decl = '%s %s;' % ('float' if size == 32 else 'double', name)
            elif kind in (SIGNED, FIXED):
                decl = 'int %s:%u;' % (name, size)
            elif kind in (UNSIGNED, FLOAT) or (kind == VOID and size):
                # Unsigned channels, odd-sized floats and padding bits.
                decl = 'unsigned %s:%u;' % (name, size)
            elif kind == VOID:
                # Zero-sized padding: nothing to declare.
                continue
            else:
                assert 0
            print('      ' + decl)

    def emit_full_fields(channels, swizzles):
        float_ctypes = {64: 'double', 32: 'float', 16: 'uint16_t'}
        for chan in channels:
            kind, name, size = chan.type, chan.name, chan.size
            assert size % 8 == 0 and is_pot(size)
            if kind == VOID:
                if not size:
                    continue
                ctype = 'uint%u_t' % size
            elif kind == UNSIGNED:
                ctype = 'uint%u_t' % size
            elif kind in (SIGNED, FIXED):
                ctype = 'int%u_t' % size
            elif kind == FLOAT and size in float_ctypes:
                ctype = float_ctypes[size]
            else:
                assert 0
            print('      %s %s;' % (ctype, name))

    print('union util_format_%s {' % format.short_name())

    block_bits = format.block_size()
    if block_bits in (8, 16, 32, 64):
        print('   uint%u_t value;' % (block_bits,))

    needs_bitfields = any(chan.size % 8 or not is_pot(chan.size)
                          for chan in format.le_channels)

    print('   struct {')
    print_channels(format, emit_bitfields if needs_bitfields else emit_full_fields)
    print('   } chan;')
    print('};')
    print()
128
129
def is_format_supported(format):
    '''Tell whether this generator can emit read/write code for a format.

    The format must use the PLAIN layout.  Each of its first four
    little-endian channels must be VOID, UNSIGNED, SIGNED, FIXED or FLOAT,
    and FLOAT channels must be 16, 32 or 64 bits wide.
    '''
    # FIXME: Ideally we would support any format combination here.
    if format.layout != PLAIN:
        return False

    def channel_ok(channel):
        if channel.type == FLOAT:
            return channel.size in (16, 32, 64)
        return channel.type in (VOID, UNSIGNED, SIGNED, FIXED)

    channels = format.le_channels
    return all(channel_ok(channels[i]) for i in range(4))
147
def native_type(format):
    '''Return the C type through which pixels of ``format`` are accessed.

    PIPE_FORMAT_R11G11B10_FLOAT and PIPE_FORMAT_R9G9B9E5_FLOAT map to
    ``uint32_t``.  Other formats must be PLAIN.  Non-array formats use the
    unsigned integer type spanning the whole block.  Array formats use the
    type of a single channel: ``uintN_t`` for UNSIGNED/VOID, ``intN_t`` for
    SIGNED/FIXED, and ``uint16_t``/``float``/``double`` for 16/32/64-bit
    FLOAT.  Anything else fails an assertion.
    '''
    if format.name in ('PIPE_FORMAT_R11G11B10_FLOAT',
                       'PIPE_FORMAT_R9G9B9E5_FLOAT'):
        return 'uint32_t'

    assert format.layout == PLAIN

    if not format.is_array():
        # Arithmetic formats: one integer covering the whole pixel.
        return 'uint%u_t' % format.block_size()

    # Array formats: the type of one color channel.
    element = format.array_element()
    kind, size = element.type, element.size
    if kind in (UNSIGNED, VOID):
        return 'uint%u_t' % size
    if kind in (SIGNED, FIXED):
        return 'int%u_t' % size
    assert kind == FLOAT
    float_ctypes = {16: 'uint16_t', 32: 'float', 64: 'double'}
    assert size in float_ctypes
    return float_ctypes[size]
180
181
def intermediate_native_type(bits, sign):
    '''Return a C integer type wide enough to hold a ``bits``-bit result.

    The width is the smallest of 32, 64, 128, ... bits that is at least
    ``bits``; nothing narrower than 32 bits is ever returned.  A truthy
    ``sign`` selects ``intN_t``, otherwise ``uintN_t``.
    '''
    width = 32  # never go narrower than 32 bits
    while width < bits:
        width *= 2
    prefix = 'int' if sign else 'uint'
    return '%s%u_t' % (prefix, width)
194
195
def get_one_shift(type):
    '''Get the number of the bit that matches unity for this type.

    Non-normalized channels return 0.  Normalized UNSIGNED channels return
    their size, normalized SIGNED channels one less (the sign bit), and
    normalized FIXED channels half their size.  FLOAT channels, and any
    other type, fail an assertion.
    '''
    # Compare against the FLOAT constant: the old comparison with the string
    # 'FLOAT' never matched, so float channels slipped through the guard.
    assert type.type != FLOAT
    if not type.norm:
        return 0
    if type.type == UNSIGNED:
        return type.size
    if type.type == SIGNED:
        return type.size - 1
    if type.type == FIXED:
        # Floor division keeps the shift an int on Python 3.
        return type.size // 2
    assert False
209
210
def truncate_mantissa(x, bits):
    '''Truncate an integer so it can be represented exactly with a floating
    point mantissa.

    Keeps only the ``bits + 1`` most significant bits of ``abs(x)``: the
    explicit mantissa bits plus the implicit leading one.  All lower bits
    are cleared and the sign is preserved.

    :param x: integer to truncate (int, or long on Python 2)
    :param bits: explicit mantissa bits of the target float, e.g. 23 for
        IEEE single precision
    :returns: the truncated integer, with the same sign as ``x``
    '''
    import numbers

    # ``long`` no longer exists on Python 3; numbers.Integral covers both
    # int and Python 2's long.
    assert isinstance(x, numbers.Integral)

    sign = -1 if x < 0 else 1
    magnitude = abs(x)

    # We can represent integers up to mantissa + 1 bits exactly; drop every
    # bit below the top ``bits + 1`` ones.
    shift = max(0, magnitude.bit_length() - (bits + 1))
    return sign * ((magnitude >> shift) << shift)
233
234
def value_to_native(type, value):
    '''Convert a real value to its native representation in a channel type.

    - FLOAT: the value is returned unchanged, except that integers headed
      for floats of 32 bits or less are truncated to 24 significant bits
      (23-bit mantissa), so they are exactly representable.
    - FIXED: the value is scaled by ``2**(size // 2)`` and truncated to int.
    - Non-normalized integer channels: the value is truncated to int.
    - Normalized UNSIGNED: scaled by ``2**size - 1`` and truncated.
    - Normalized SIGNED: scaled by ``2**(size - 1) - 1`` and truncated.

    Any other type fails an assertion.
    '''
    import numbers

    if type.type == FLOAT:
        # ``long`` is undefined on Python 3; numbers.Integral covers int and
        # Python 2's long alike.
        if type.size <= 32 \
            and isinstance(value, numbers.Integral):
            return truncate_mantissa(value, 23)
        return value
    if type.type == FIXED:
        # Floor division: a float shift count raises TypeError on Python 3.
        return int(value * (1 << (type.size // 2)))
    if not type.norm:
        return int(value)
    if type.type == UNSIGNED:
        return int(value * ((1 << type.size) - 1))
    if type.type == SIGNED:
        return int(value * ((1 << (type.size - 1)) - 1))
    assert False
251
252
def native_to_constant(type, value):
    '''Format a native channel value as a C constant literal.

    FLOAT channels of up to 32 bits give a one-decimal literal with an ``f``
    suffix.  Wider floats give the same literal without the suffix.  Every
    other channel type gives a plain decimal integer.
    '''
    if type.type != FLOAT:
        return str(int(value))
    literal = "%.1f" % float(value)
    return literal + "f" if type.size <= 32 else literal
262
263
def get_one(type):
    '''Return the native representation of 1.0 (unity) for a channel type.'''
    return value_to_native(type, value=1)
267
268
def clamp_expr(src_channel, dst_channel, dst_native_type, value):
    '''Wrap a C expression so it is clamped to the destination channel range.

    ``value`` is a C expression in the source channel's native
    representation.  The destination min/max are translated into that
    representation.  The expression is wrapped in CLAMP, MIN2 or MAX2 only
    for the bounds the source range actually exceeds.  Identical channels,
    or a source range inside the destination range, return ``value``
    unchanged.  ``dst_native_type`` is accepted for interface compatibility
    but not used.
    '''
    if src_channel == dst_channel:
        return value

    dst_lo, dst_hi = dst_channel.min(), dst_channel.max()
    exceeds_lo = src_channel.min() < dst_lo
    exceeds_hi = src_channel.max() > dst_hi

    def as_src_constant(bound):
        # Express a destination bound as a constant in the source type.
        return native_to_constant(src_channel,
                                  value_to_native(src_channel, bound))

    lo_text = as_src_constant(dst_lo)
    hi_text = as_src_constant(dst_hi)

    if exceeds_lo and exceeds_hi:
        return 'CLAMP(%s, %s, %s)' % (value, lo_text, hi_text)
    if exceeds_hi:
        return 'MIN2(%s, %s)' % (value, hi_text)
    if exceeds_lo:
        return 'MAX2(%s, %s)' % (value, lo_text)
    return value
295
296
def conversion_expr(src_channel,
                    dst_channel, dst_native_type,
                    value,
                    clamp=True,
                    src_colorspace = RGB,
                    dst_colorspace = RGB):
    '''Generate the expression to convert a value between two types.

    :param src_channel: channel describing the type of ``value``
    :param dst_channel: channel describing the type to convert to
    :param dst_native_type: C type name used for casts to the destination
    :param value: C expression (string) holding the source value
    :param clamp: when true, clamp into the destination range first
        (skipped when both source and destination are FLOAT)
    :param src_colorspace: colorspace of the source value
    :param dst_colorspace: colorspace of the destination value
    :returns: a C expression string producing the converted value

    sRGB <-> linear conversions are handled first, via the
    util_format_*srgb* helpers.  Otherwise, integer <-> integer conversions
    use casts, bit shifts or an integer rescale, and every other
    combination goes through float.  Double is used instead when the source
    has more bits than a single-precision mantissa can hold.
    '''

    if src_colorspace != dst_colorspace:
        if src_colorspace == SRGB:
            # sRGB -> linear: only 4..8-bit unsigned normalized sources are
            # accepted.
            assert src_channel.type == UNSIGNED
            assert src_channel.norm
            assert src_channel.size <= 8
            assert src_channel.size >= 4
            assert dst_colorspace == RGB
            if src_channel.size < 8:
                # Widen to 8 bits by replicating the top bits into the low
                # bits.  The shift amounts are printed with %x; that matches
                # decimal only because they are always below 10 here.
                value = '%s << %x | %s >> %x' % (value, 8 - src_channel.size, value, 2 * src_channel.size - 8)
            if dst_channel.type == FLOAT:
                return 'util_format_srgb_8unorm_to_linear_float(%s)' % value
            else:
                assert dst_channel.type == UNSIGNED
                assert dst_channel.norm
                assert dst_channel.size == 8
                return 'util_format_srgb_to_linear_8unorm(%s)' % value
        elif dst_colorspace == SRGB:
            # linear -> sRGB: produce an 8-bit sRGB value, then drop low bits
            # when the destination channel is narrower than 8 bits.
            assert dst_channel.type == UNSIGNED
            assert dst_channel.norm
            assert dst_channel.size <= 8
            assert src_colorspace == RGB
            if src_channel.type == FLOAT:
                value =  'util_format_linear_float_to_srgb_8unorm(%s)' % value
            else:
                assert src_channel.type == UNSIGNED
                assert src_channel.norm
                assert src_channel.size == 8
                value = 'util_format_linear_to_srgb_8unorm(%s)' % value
            # XXX rounding is all wrong.
            if dst_channel.size < 8:
                return '%s >> %x' % (value, 8 - dst_channel.size)
            else:
                return value
        elif src_colorspace == ZS:
            # No colorspace conversion for ZS; fall through to the plain type
            # conversion below.
            pass
        elif dst_colorspace == ZS:
            pass
        else:
            assert 0

    # Same channel description on both sides: nothing to convert.
    if src_channel == dst_channel:
        return value

    # Local copies that are updated as the value gets promoted below.
    src_type = src_channel.type
    src_size = src_channel.size
    src_norm = src_channel.norm
    src_pure = src_channel.pure  # NOTE(review): not used below.

    # Promote half to float
    if src_type == FLOAT and src_size == 16:
        value = 'util_half_to_float(%s)' % value
        src_size = 32

    # Special case for float <-> ubytes for more accurate results
    # Done before clamping since these functions already take care of that
    if src_type == UNSIGNED and src_norm and src_size == 8 and dst_channel.type == FLOAT and dst_channel.size == 32:
        return 'ubyte_to_float(%s)' % value
    if src_type == FLOAT and src_size == 32 and dst_channel.type == UNSIGNED and dst_channel.norm and dst_channel.size == 8:
        return 'float_to_ubyte(%s)' % value

    # Clamp into the destination range unless both sides are floats.
    if clamp:
        if dst_channel.type != FLOAT or src_type != FLOAT:
            value = clamp_expr(src_channel, dst_channel, dst_native_type, value)

    # Integer -> integer conversions never go through float.
    if src_type in (SIGNED, UNSIGNED) and dst_channel.type in (SIGNED, UNSIGNED):
        if not src_norm and not dst_channel.norm:
            # neither is normalized -- just cast
            return '(%s)%s' % (dst_native_type, value)

        src_one = get_one(src_channel)
        dst_one = get_one(dst_channel)

        if src_one > dst_one and src_norm and dst_channel.norm:
            # We can just bitshift
            src_shift = get_one_shift(src_channel)
            dst_shift = get_one_shift(dst_channel)
            value = '(%s >> %s)' % (value, src_shift - dst_shift)
        else:
            # We need to rescale using an intermediate type big enough to hold the multiplication of both
            tmp_native_type = intermediate_native_type(src_size + dst_channel.size, src_channel.sign and dst_channel.sign)
            value = '((%s)%s)' % (tmp_native_type, value)
            # value * dst_one / src_one, evaluated in the wide type.
            value = '(%s * 0x%x / 0x%x)' % (value, dst_one, src_one)
        value = '(%s)%s' % (dst_native_type, value)
        return value

    # Promote to either float or double
    if src_type != FLOAT:
        if src_norm or src_type == FIXED:
            # Scale so the source's unity value maps to 1.0.
            one = get_one(src_channel)
            if src_size <= 23:
                value = '(%s * (1.0f/0x%x))' % (value, one)
                if dst_channel.size <= 32:
                    value = '(float)%s' % value
                src_size = 32
            else:
                # bigger than single precision mantissa, use double
                value = '(%s * (1.0/0x%x))' % (value, one)
                src_size = 64
            src_norm = False
        else:
            if src_size <= 23 or dst_channel.size <= 32:
                value = '(float)%s' % value
                src_size = 32
            else:
                # bigger than single precision mantissa, use double
                value = '(double)%s' % value
                src_size = 64
        src_type = FLOAT

    # Convert double or float to non-float
    if dst_channel.type != FLOAT:
        if dst_channel.norm or dst_channel.type == FIXED:
            dst_one = get_one(dst_channel)
            if dst_channel.size <= 23:
                value = 'util_iround(%s * 0x%x)' % (value, dst_one)
            else:
                # bigger than single precision mantissa, use double
                # NOTE(review): unlike the util_iround path, the final C
                # integer cast truncates here instead of rounding.
                value = '(%s * (double)0x%x)' % (value, dst_one)
        value = '(%s)%s' % (dst_native_type, value)
    else:
        # Cast double to float when converting to either half or float
        if dst_channel.size <= 32 and src_size > 32:
            value = '(float)%s' % value
            src_size = 32

        if dst_channel.size == 16:
            value = 'util_float_to_half(%s)' % value
        elif dst_channel.size == 64 and src_size < 64:
            value = '(double)%s' % value

    return value
436
437
def generate_unpack_kernel(format, dst_channel, dst_native_type):
    '''Print the C statements that unpack one block from ``src`` into ``dst``.

    The emitted code reads one block of ``format`` from ``src``.  It writes
    four components (r, g, b, a) to ``dst[0..3]``, converted with
    conversion_expr() to the type described by ``dst_channel`` /
    ``dst_native_type``.  Nothing is printed for formats rejected by
    is_format_supported().
    '''

    if not is_format_supported(format):
        return

    assert format.layout == PLAIN

    src_native_type = native_type(format)  # NOTE(review): not used below.

    def unpack_from_bitmask(channels, swizzles):
        # Load the whole block as one integer, then extract every channel
        # with shifts and masks.
        depth = format.block_size()
        print('         uint%u_t value = *(const uint%u_t *)src;' % (depth, depth))

        # Declare the intermediate variables
        # (only UNSIGNED and SIGNED channels get one).
        for i in range(format.nr_channels()):
            src_channel = channels[i]
            if src_channel.type == UNSIGNED:
                print('         uint%u_t %s;' % (depth, src_channel.name))
            elif src_channel.type == SIGNED:
                print('         int%u_t %s;' % (depth, src_channel.name))

        # Compute the intermediate unshifted values
        for i in range(format.nr_channels()):
            src_channel = channels[i]
            value = 'value'
            shift = src_channel.shift
            if src_channel.type == UNSIGNED:
                if shift:
                    value = '%s >> %u' % (value, shift)
                # No mask is needed when the channel reaches the top bit.
                if shift + src_channel.size < depth:
                    value = '(%s) & 0x%x' % (value, (1 << src_channel.size) - 1)
            elif src_channel.type == SIGNED:
                # Move the channel's sign bit to the top of the word, then
                # reinterpret as signed and shift right to sign-extend.  This
                # relies on the C compiler using an arithmetic right shift
                # for signed values.
                if shift + src_channel.size < depth:
                    # Align the sign bit
                    lshift = depth - (shift + src_channel.size)
                    value = '%s << %u' % (value, lshift)
                # Cast to signed
                value = '(int%u_t)(%s) ' % (depth, value)
                if src_channel.size < depth:
                    # Align the LSB bit
                    rshift = depth - src_channel.size
                    value = '(%s) >> %u' % (value, rshift)
            else:
                # Other channel types produce no intermediate variable.
                value = None

            if value is not None:
                print('         %s = %s;' % (src_channel.name, value))

        # Convert, swizzle, and store final values
        for i in range(4):
            swizzle = swizzles[i]
            if swizzle < 4:
                src_channel = channels[swizzle]
                src_colorspace = format.colorspace
                if src_colorspace == SRGB and i == 3:
                    # Alpha channel is linear
                    src_colorspace = RGB
                value = src_channel.name
                value = conversion_expr(src_channel,
                                        dst_channel, dst_native_type,
                                        value,
                                        src_colorspace = src_colorspace)
            elif swizzle == SWIZZLE_0:
                value = '0'
            elif swizzle == SWIZZLE_1:
                value = get_one(dst_channel)
            elif swizzle == SWIZZLE_NONE:
                # Components without a source are written as 0.
                value = '0'
            else:
                assert False
            print('         dst[%u] = %s; /* %s */' % (i, value, 'rgba'[i]))

    def unpack_from_union(channels, swizzles):
        # Copy the block into the format's union type (as printed by
        # generate_format_type) and read each channel through its member.
        print('         union util_format_%s pixel;' % format.short_name())
        print('         memcpy(&pixel, src, sizeof pixel);')

        for i in range(4):
            swizzle = swizzles[i]
            if swizzle < 4:
                src_channel = channels[swizzle]
                src_colorspace = format.colorspace
                if src_colorspace == SRGB and i == 3:
                    # Alpha channel is linear
                    src_colorspace = RGB
                value = 'pixel.chan.%s' % src_channel.name
                value = conversion_expr(src_channel,
                                        dst_channel, dst_native_type,
                                        value,
                                        src_colorspace = src_colorspace)
            elif swizzle == SWIZZLE_0:
                value = '0'
            elif swizzle == SWIZZLE_1:
                value = get_one(dst_channel)
            elif swizzle == SWIZZLE_NONE:
                # Components without a source are written as 0.
                value = '0'
            else:
                assert False
            print('         dst[%u] = %s; /* %s */' % (i, value, 'rgba'[i]))

    # Bitmask formats are decoded with shifts/masks; all others go through
    # the per-format union.  print_channels emits both endian variants when
    # the format has more than one channel.
    if format.is_bitmask():
        print_channels(format, unpack_from_bitmask)
    else:
        print_channels(format, unpack_from_union)
541
542
def generate_pack_kernel(format, src_channel, src_native_type):
    '''Print the C statements that pack one block from ``src`` into ``dst``.

    The emitted code reads components ``src[0..3]`` (of the type described
    by ``src_channel``), converts them with conversion_expr() to the
    format's channel types, and stores one block at ``dst``.  Nothing is
    printed for formats rejected by is_format_supported().
    ``src_native_type`` is not referenced in this function.
    '''

    if not is_format_supported(format):
        return

    dst_native_type = native_type(format)

    assert format.layout == PLAIN

    def pack_into_bitmask(channels, swizzles):
        # inv_swizzle[c] = index of the src component feeding channel c.
        inv_swizzle = inv_swizzles(swizzles)

        depth = format.block_size()
        print('         uint%u_t value = 0;' % depth)

        for i in range(4):
            dst_channel = channels[i]
            shift = dst_channel.shift
            # Channels not fed by any component are left as zero bits.
            if inv_swizzle[i] is not None:
                value ='src[%u]' % inv_swizzle[i]
                dst_colorspace = format.colorspace
                if dst_colorspace == SRGB and inv_swizzle[i] == 3:
                    # Alpha channel is linear
                    dst_colorspace = RGB
                value = conversion_expr(src_channel,
                                        dst_channel, dst_native_type,
                                        value,
                                        dst_colorspace = dst_colorspace)
                if dst_channel.type in (UNSIGNED, SIGNED):
                    # Mask to the channel width (unneeded when the channel
                    # reaches the top bit), shift into position, and
                    # reinterpret signed results as unsigned before OR-ing.
                    if shift + dst_channel.size < depth:
                        value = '(%s) & 0x%x' % (value, (1 << dst_channel.size) - 1)
                    if shift:
                        value = '(%s) << %u' % (value, shift)
                    if dst_channel.type == SIGNED:
                        # Cast to unsigned
                        value = '(uint%u_t)(%s) ' % (depth, value)
                else:
                    # Other channel types are not packed on this path.
                    value = None
                if value is not None:
                    print('         value |= %s;' % (value))

        print('         *(uint%u_t *)dst = value;' % depth)

    def pack_into_union(channels, swizzles):
        # Fill the format's union type (as printed by generate_format_type)
        # member by member, then copy it out to dst.
        inv_swizzle = inv_swizzles(swizzles)

        print('         union util_format_%s pixel;' % format.short_name())

        for i in range(4):
            dst_channel = channels[i]
            width = dst_channel.size  # NOTE(review): not used below.
            if inv_swizzle[i] is None:
                continue
            dst_colorspace = format.colorspace
            if dst_colorspace == SRGB and inv_swizzle[i] == 3:
                # Alpha channel is linear
                dst_colorspace = RGB
            value ='src[%u]' % inv_swizzle[i]
            value = conversion_expr(src_channel,
                                    dst_channel, dst_native_type,
                                    value,
                                    dst_colorspace = dst_colorspace)
            print('         pixel.chan.%s = %s;' % (dst_channel.name, value))

        print('         memcpy(dst, &pixel, sizeof pixel);')

    # Bitmask formats are assembled with shifts/masks; all others go through
    # the per-format union.
    if format.is_bitmask():
        print_channels(format, pack_into_bitmask)
    else:
        print_channels(format, pack_into_union)
613
614
def generate_format_unpack(format, dst_channel, dst_native_type, dst_suffix):
    '''Generate the function to unpack pixels from a particular format.

    Prints a ``static inline`` C function named
    ``util_format_<short name>_unpack_<dst_suffix>``.  The function walks a
    width x height region block by block, runs the unpack kernel for each
    block, and advances ``dst`` by 4 components per block.  When the format
    is not supported by this generator, the function body is left empty.

    :param format: the format to unpack from
    :param dst_channel: channel describing the destination component type
    :param dst_native_type: C type of the destination components
    :param dst_suffix: suffix appended to the generated function name
    '''

    name = format.short_name()

    print('static inline void')
    print('util_format_%s_unpack_%s(%s *dst_row, unsigned dst_stride, const uint8_t *src_row, unsigned src_stride, unsigned width, unsigned height)' % (name, dst_suffix, dst_native_type))
    print('{')

    if is_format_supported(format):
        print('   unsigned x, y;')
        print('   for(y = 0; y < height; y += %u) {' % (format.block_height,))
        print('      %s *dst = dst_row;' % (dst_native_type))
        print('      const uint8_t *src = src_row;')
        print('      for(x = 0; x < width; x += %u) {' % (format.block_width,))

        generate_unpack_kernel(format, dst_channel, dst_native_type)

        # Bytes per block.  Floor division keeps this an int on Python 3
        # instead of relying on '%u' silently truncating a float.
        print('         src += %u;' % (format.block_size() // 8,))
        print('         dst += 4;')
        print('      }')
        print('      src_row += src_stride;')
        # dst_row is a typed pointer, so the byte stride is scaled down.
        print('      dst_row += dst_stride/sizeof(*dst_row);')
        print('   }')

    print('}')
    print()
642
643
def generate_format_pack(format, src_channel, src_native_type, src_suffix):
    '''Generate the function to pack pixels to a particular format.

    Prints a ``static inline`` C function named
    ``util_format_<short name>_pack_<src_suffix>``.  The function walks a
    width x height region block by block, runs the pack kernel for each
    block, and advances ``src`` by 4 components per block.  When the format
    is not supported by this generator, the function body is left empty.

    :param format: the format to pack into
    :param src_channel: channel describing the source component type
    :param src_native_type: C type of the source components
    :param src_suffix: suffix appended to the generated function name
    '''

    name = format.short_name()

    print('static inline void')
    print('util_format_%s_pack_%s(uint8_t *dst_row, unsigned dst_stride, const %s *src_row, unsigned src_stride, unsigned width, unsigned height)' % (name, src_suffix, src_native_type))
    print('{')

    if is_format_supported(format):
        print('   unsigned x, y;')
        print('   for(y = 0; y < height; y += %u) {' % (format.block_height,))
        print('      const %s *src = src_row;' % (src_native_type))
        print('      uint8_t *dst = dst_row;')
        print('      for(x = 0; x < width; x += %u) {' % (format.block_width,))

        generate_pack_kernel(format, src_channel, src_native_type)

        print('         src += 4;')
        # Bytes per block.  Floor division keeps this an int on Python 3
        # instead of relying on '%u' silently truncating a float.
        print('         dst += %u;' % (format.block_size() // 8,))
        print('      }')
        print('      dst_row += dst_stride;')
        # src_row is a typed pointer, so the byte stride is scaled down.
        print('      src_row += src_stride/sizeof(*src_row);')
        print('   }')

    print('}')
    print()
671
672
def generate_format_fetch(format, dst_channel, dst_native_type, dst_suffix):
    '''Print a C function that unpacks the single block at ``src`` into ``dst``.

    The function is named ``util_format_<short name>_fetch_<dst_suffix>``.
    Its body is the unpack kernel for ``format``, or empty when the format is
    not supported by this generator.
    '''
    signature = ('util_format_%s_fetch_%s(%s *dst, const uint8_t *src, '
                 'unsigned i, unsigned j)'
                 % (format.short_name(), dst_suffix, dst_native_type))

    print('static inline void')
    print(signature)
    print('{')
    if is_format_supported(format):
        generate_unpack_kernel(format, dst_channel, dst_native_type)
    print('}')
    print()
687
688
def is_format_hand_written(format):
    '''Return True for formats whose pack/unpack code is not generated here.

    That covers the 's3tc', 'rgtc', 'etc', 'bptc', 'subsampled' and 'other'
    layouts, as well as every format in the ZS colorspace.
    '''
    if format.colorspace == ZS:
        return True
    hand_written_layouts = ('s3tc', 'rgtc', 'etc', 'bptc', 'subsampled', 'other')
    return format.layout in hand_written_layouts
691
692
693def generate(formats):
694    print()
695    print('#include "pipe/p_compiler.h"')
696    print('#include "u_math.h"')
697    print('#include "u_half.h"')
698    print('#include "u_format.h"')
699    print()
700
701