1 /*
2  * Copyright © 2017 Advanced Micro Devices, Inc.
3  * All Rights Reserved.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining
6  * a copy of this software and associated documentation files (the
7  * "Software"), to deal in the Software without restriction, including
8  * without limitation the rights to use, copy, modify, merge, publish,
9  * distribute, sub license, and/or sell copies of the Software, and to
10  * permit persons to whom the Software is furnished to do so, subject to
11  * the following conditions:
12  *
13  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
15  * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
16  * NON-INFRINGEMENT. IN NO EVENT SHALL THE COPYRIGHT HOLDERS, AUTHORS
17  * AND/OR ITS SUPPLIERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
19  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE
20  * USE OR OTHER DEALINGS IN THE SOFTWARE.
21  *
22  * The above copyright notice and this permission notice (including the
23  * next paragraph) shall be included in all copies or substantial portions
24  * of the Software.
25  */
26 
27 // Coordinate class implementation
28 #include "addrcommon.h"
29 #include "coord.h"
30 
Coordinate()31 Coordinate::Coordinate()
32 {
33     dim = 'x';
34     ord = 0;
35 }
36 
Coordinate(INT_8 c,INT_32 n)37 Coordinate::Coordinate(INT_8 c, INT_32 n)
38 {
39     set(c, n);
40 }
41 
set(INT_8 c,INT_32 n)42 VOID Coordinate::set(INT_8 c, INT_32 n)
43 {
44     dim = c;
45     ord = static_cast<INT_8>(n);
46 }
47 
ison(UINT_32 x,UINT_32 y,UINT_32 z,UINT_32 s,UINT_32 m) const48 UINT_32 Coordinate::ison(UINT_32 x, UINT_32 y, UINT_32 z, UINT_32 s, UINT_32 m) const
49 {
50     UINT_32 bit = static_cast<UINT_32>(1ull << static_cast<UINT_32>(ord));
51     UINT_32 out = 0;
52 
53     switch (dim)
54     {
55     case 'm': out = m & bit; break;
56     case 's': out = s & bit; break;
57     case 'x': out = x & bit; break;
58     case 'y': out = y & bit; break;
59     case 'z': out = z & bit; break;
60     }
61     return (out != 0) ? 1 : 0;
62 }
63 
getdim()64 INT_8 Coordinate::getdim()
65 {
66     return dim;
67 }
68 
getord()69 INT_8 Coordinate::getord()
70 {
71     return ord;
72 }
73 
operator ==(const Coordinate & b)74 BOOL_32 Coordinate::operator==(const Coordinate& b)
75 {
76     return (dim == b.dim) && (ord == b.ord);
77 }
78 
operator <(const Coordinate & b)79 BOOL_32 Coordinate::operator<(const Coordinate& b)
80 {
81     BOOL_32 ret;
82 
83     if (dim == b.dim)
84     {
85         ret = ord < b.ord;
86     }
87     else
88     {
89         if (dim == 's' || b.dim == 'm')
90         {
91             ret = TRUE;
92         }
93         else if (b.dim == 's' || dim == 'm')
94         {
95             ret = FALSE;
96         }
97         else if (ord == b.ord)
98         {
99             ret = dim < b.dim;
100         }
101         else
102         {
103             ret = ord < b.ord;
104         }
105     }
106 
107     return ret;
108 }
109 
operator >(const Coordinate & b)110 BOOL_32 Coordinate::operator>(const Coordinate& b)
111 {
112     BOOL_32 lt = *this < b;
113     BOOL_32 eq = *this == b;
114     return !lt && !eq;
115 }
116 
operator <=(const Coordinate & b)117 BOOL_32 Coordinate::operator<=(const Coordinate& b)
118 {
119     return (*this < b) || (*this == b);
120 }
121 
operator >=(const Coordinate & b)122 BOOL_32 Coordinate::operator>=(const Coordinate& b)
123 {
124     return !(*this < b);
125 }
126 
operator !=(const Coordinate & b)127 BOOL_32 Coordinate::operator!=(const Coordinate& b)
128 {
129     return !(*this == b);
130 }
131 
operator ++(INT_32)132 Coordinate& Coordinate::operator++(INT_32)
133 {
134     ord++;
135     return *this;
136 }
137 
138 // CoordTerm
139 
CoordTerm()140 CoordTerm::CoordTerm()
141 {
142     num_coords = 0;
143 }
144 
Clear()145 VOID CoordTerm::Clear()
146 {
147     num_coords = 0;
148 }
149 
add(Coordinate & co)150 VOID CoordTerm::add(Coordinate& co)
151 {
152     // This function adds a coordinate INT_32o the list
153     // It will prevent the same coordinate from appearing,
154     // and will keep the list ordered from smallest to largest
155     UINT_32 i;
156 
157     for (i = 0; i < num_coords; i++)
158     {
159         if (m_coord[i] == co)
160         {
161             break;
162         }
163         if (m_coord[i] > co)
164         {
165             for (UINT_32 j = num_coords; j > i; j--)
166             {
167                 m_coord[j] = m_coord[j - 1];
168             }
169             m_coord[i] = co;
170             num_coords++;
171             break;
172         }
173     }
174 
175     if (i == num_coords)
176     {
177         m_coord[num_coords] = co;
178         num_coords++;
179     }
180 }
181 
add(CoordTerm & cl)182 VOID CoordTerm::add(CoordTerm& cl)
183 {
184     for (UINT_32 i = 0; i < cl.num_coords; i++)
185     {
186         add(cl.m_coord[i]);
187     }
188 }
189 
remove(Coordinate & co)190 BOOL_32 CoordTerm::remove(Coordinate& co)
191 {
192     BOOL_32 remove = FALSE;
193     for (UINT_32 i = 0; i < num_coords; i++)
194     {
195         if (m_coord[i] == co)
196         {
197             remove = TRUE;
198             num_coords--;
199         }
200 
201         if (remove)
202         {
203             m_coord[i] = m_coord[i + 1];
204         }
205     }
206     return remove;
207 }
208 
Exists(Coordinate & co)209 BOOL_32 CoordTerm::Exists(Coordinate& co)
210 {
211     BOOL_32 exists = FALSE;
212     for (UINT_32 i = 0; i < num_coords; i++)
213     {
214         if (m_coord[i] == co)
215         {
216             exists = TRUE;
217             break;
218         }
219     }
220     return exists;
221 }
222 
copyto(CoordTerm & cl)223 VOID CoordTerm::copyto(CoordTerm& cl)
224 {
225     cl.num_coords = num_coords;
226     for (UINT_32 i = 0; i < num_coords; i++)
227     {
228         cl.m_coord[i] = m_coord[i];
229     }
230 }
231 
getsize()232 UINT_32 CoordTerm::getsize()
233 {
234     return num_coords;
235 }
236 
getxor(UINT_32 x,UINT_32 y,UINT_32 z,UINT_32 s,UINT_32 m) const237 UINT_32 CoordTerm::getxor(UINT_32 x, UINT_32 y, UINT_32 z, UINT_32 s, UINT_32 m) const
238 {
239     UINT_32 out = 0;
240     for (UINT_32 i = 0; i < num_coords; i++)
241     {
242         out = out ^ m_coord[i].ison(x, y, z, s, m);
243     }
244     return out;
245 }
246 
getsmallest(Coordinate & co)247 VOID CoordTerm::getsmallest(Coordinate& co)
248 {
249     co = m_coord[0];
250 }
251 
Filter(INT_8 f,Coordinate & co,UINT_32 start,INT_8 axis)252 UINT_32 CoordTerm::Filter(INT_8 f, Coordinate& co, UINT_32 start, INT_8 axis)
253 {
254     for (UINT_32 i = start;  i < num_coords;)
255     {
256         if (((f == '<' && m_coord[i] < co) ||
257              (f == '>' && m_coord[i] > co) ||
258              (f == '=' && m_coord[i] == co)) &&
259             (axis == '\0' || axis == m_coord[i].getdim()))
260         {
261             for (UINT_32 j = i; j < num_coords - 1; j++)
262             {
263                 m_coord[j] = m_coord[j + 1];
264             }
265             num_coords--;
266         }
267         else
268         {
269             i++;
270         }
271     }
272     return num_coords;
273 }
274 
operator [](UINT_32 i)275 Coordinate& CoordTerm::operator[](UINT_32 i)
276 {
277     return m_coord[i];
278 }
279 
operator ==(const CoordTerm & b)280 BOOL_32 CoordTerm::operator==(const CoordTerm& b)
281 {
282     BOOL_32 ret = TRUE;
283 
284     if (num_coords != b.num_coords)
285     {
286         ret = FALSE;
287     }
288     else
289     {
290         for (UINT_32 i = 0; i < num_coords; i++)
291         {
292             // Note: the lists will always be in order, so we can compare the two lists at time
293             if (m_coord[i] != b.m_coord[i])
294             {
295                 ret = FALSE;
296                 break;
297             }
298         }
299     }
300     return ret;
301 }
302 
operator !=(const CoordTerm & b)303 BOOL_32 CoordTerm::operator!=(const CoordTerm& b)
304 {
305     return !(*this == b);
306 }
307 
exceedRange(UINT_32 xRange,UINT_32 yRange,UINT_32 zRange,UINT_32 sRange)308 BOOL_32 CoordTerm::exceedRange(UINT_32 xRange, UINT_32 yRange, UINT_32 zRange, UINT_32 sRange)
309 {
310     BOOL_32 exceed = FALSE;
311     for (UINT_32 i = 0; (i < num_coords) && (exceed == FALSE); i++)
312     {
313         UINT_32 subject;
314         switch (m_coord[i].getdim())
315         {
316             case 'x':
317                 subject = xRange;
318                 break;
319             case 'y':
320                 subject = yRange;
321                 break;
322             case 'z':
323                 subject = zRange;
324                 break;
325             case 's':
326                 subject = sRange;
327                 break;
328             case 'm':
329                 subject = 0;
330                 break;
331             default:
332                 // Invalid input!
333                 ADDR_ASSERT_ALWAYS();
334                 subject = 0;
335                 break;
336         }
337 
338         exceed = ((1u << m_coord[i].getord()) <= subject);
339     }
340 
341     return exceed;
342 }
343 
344 // coordeq
CoordEq()345 CoordEq::CoordEq()
346 {
347     m_numBits = 0;
348 }
349 
remove(Coordinate & co)350 VOID CoordEq::remove(Coordinate& co)
351 {
352     for (UINT_32 i = 0; i < m_numBits; i++)
353     {
354         m_eq[i].remove(co);
355     }
356 }
357 
Exists(Coordinate & co)358 BOOL_32 CoordEq::Exists(Coordinate& co)
359 {
360     BOOL_32 exists = FALSE;
361 
362     for (UINT_32 i = 0; i < m_numBits; i++)
363     {
364         if (m_eq[i].Exists(co))
365         {
366             exists = TRUE;
367         }
368     }
369     return exists;
370 }
371 
resize(UINT_32 n)372 VOID CoordEq::resize(UINT_32 n)
373 {
374     if (n > m_numBits)
375     {
376         for (UINT_32 i = m_numBits; i < n; i++)
377         {
378             m_eq[i].Clear();
379         }
380     }
381     m_numBits = n;
382 }
383 
getsize()384 UINT_32 CoordEq::getsize()
385 {
386     return m_numBits;
387 }
388 
solve(UINT_32 x,UINT_32 y,UINT_32 z,UINT_32 s,UINT_32 m) const389 UINT_64 CoordEq::solve(UINT_32 x, UINT_32 y, UINT_32 z, UINT_32 s, UINT_32 m) const
390 {
391     UINT_64 out = 0;
392     for (UINT_32 i = 0; i < m_numBits; i++)
393     {
394         if (m_eq[i].getxor(x, y, z, s, m) != 0)
395         {
396             out |= (1ULL << i);
397         }
398     }
399     return out;
400 }
401 
solveAddr(UINT_64 addr,UINT_32 sliceInM,UINT_32 & x,UINT_32 & y,UINT_32 & z,UINT_32 & s,UINT_32 & m) const402 VOID CoordEq::solveAddr(
403     UINT_64 addr, UINT_32 sliceInM,
404     UINT_32& x, UINT_32& y, UINT_32& z, UINT_32& s, UINT_32& m) const
405 {
406     UINT_32 xBitsValid = 0;
407     UINT_32 yBitsValid = 0;
408     UINT_32 zBitsValid = 0;
409     UINT_32 sBitsValid = 0;
410     UINT_32 mBitsValid = 0;
411 
412     CoordEq temp = *this;
413 
414     x = y = z = s = m = 0;
415 
416     UINT_32 bitsLeft = 0;
417 
418     for (UINT_32 i = 0; i < temp.m_numBits; i++)
419     {
420         UINT_32 termSize = temp.m_eq[i].getsize();
421 
422         if (termSize == 1)
423         {
424             INT_8 bit = (addr >> i) & 1;
425             INT_8 dim = temp.m_eq[i][0].getdim();
426             INT_8 ord = temp.m_eq[i][0].getord();
427 
428             ADDR_ASSERT((ord < 32) || (bit == 0));
429 
430             switch (dim)
431             {
432                 case 'x':
433                     xBitsValid |= (1 << ord);
434                     x |= (bit << ord);
435                     break;
436                 case 'y':
437                     yBitsValid |= (1 << ord);
438                     y |= (bit << ord);
439                     break;
440                 case 'z':
441                     zBitsValid |= (1 << ord);
442                     z |= (bit << ord);
443                     break;
444                 case 's':
445                     sBitsValid |= (1 << ord);
446                     s |= (bit << ord);
447                     break;
448                 case 'm':
449                     mBitsValid |= (1 << ord);
450                     m |= (bit << ord);
451                     break;
452                 default:
453                     break;
454             }
455 
456             temp.m_eq[i].Clear();
457         }
458         else if (termSize > 1)
459         {
460             bitsLeft++;
461         }
462     }
463 
464     if (bitsLeft > 0)
465     {
466         if (sliceInM != 0)
467         {
468             z = m / sliceInM;
469             zBitsValid = 0xffffffff;
470         }
471 
472         do
473         {
474             bitsLeft = 0;
475 
476             for (UINT_32 i = 0; i < temp.m_numBits; i++)
477             {
478                 UINT_32 termSize = temp.m_eq[i].getsize();
479 
480                 if (termSize == 1)
481                 {
482                     INT_8 bit = (addr >> i) & 1;
483                     INT_8 dim = temp.m_eq[i][0].getdim();
484                     INT_8 ord = temp.m_eq[i][0].getord();
485 
486                     ADDR_ASSERT((ord < 32) || (bit == 0));
487 
488                     switch (dim)
489                     {
490                         case 'x':
491                             xBitsValid |= (1 << ord);
492                             x |= (bit << ord);
493                             break;
494                         case 'y':
495                             yBitsValid |= (1 << ord);
496                             y |= (bit << ord);
497                             break;
498                         case 'z':
499                             zBitsValid |= (1 << ord);
500                             z |= (bit << ord);
501                             break;
502                         case 's':
503                             ADDR_ASSERT_ALWAYS();
504                             break;
505                         case 'm':
506                             ADDR_ASSERT_ALWAYS();
507                             break;
508                         default:
509                             break;
510                     }
511 
512                     temp.m_eq[i].Clear();
513                 }
514                 else if (termSize > 1)
515                 {
516                     CoordTerm tmpTerm = temp.m_eq[i];
517 
518                     for (UINT_32 j = 0; j < termSize; j++)
519                     {
520                         INT_8 dim = temp.m_eq[i][j].getdim();
521                         INT_8 ord = temp.m_eq[i][j].getord();
522 
523                         switch (dim)
524                         {
525                             case 'x':
526                                 if (xBitsValid & (1 << ord))
527                                 {
528                                     UINT_32 v = (((x >> ord) & 1) << i);
529                                     addr ^= static_cast<UINT_64>(v);
530                                     tmpTerm.remove(temp.m_eq[i][j]);
531                                 }
532                                 break;
533                             case 'y':
534                                 if (yBitsValid & (1 << ord))
535                                 {
536                                     UINT_32 v = (((y >> ord) & 1) << i);
537                                     addr ^= static_cast<UINT_64>(v);
538                                     tmpTerm.remove(temp.m_eq[i][j]);
539                                 }
540                                 break;
541                             case 'z':
542                                 if (zBitsValid & (1 << ord))
543                                 {
544                                     UINT_32 v = (((z >> ord) & 1) << i);
545                                     addr ^= static_cast<UINT_64>(v);
546                                     tmpTerm.remove(temp.m_eq[i][j]);
547                                 }
548                                 break;
549                             case 's':
550                                 ADDR_ASSERT_ALWAYS();
551                                 break;
552                             case 'm':
553                                 ADDR_ASSERT_ALWAYS();
554                                 break;
555                             default:
556                                 break;
557                         }
558                     }
559 
560                     temp.m_eq[i] = tmpTerm;
561 
562                     bitsLeft++;
563                 }
564             }
565         } while (bitsLeft > 0);
566     }
567 }
568 
copy(CoordEq & o,UINT_32 start,UINT_32 num)569 VOID CoordEq::copy(CoordEq& o, UINT_32 start, UINT_32 num)
570 {
571     o.m_numBits = (num == 0xFFFFFFFF) ? m_numBits : num;
572     for (UINT_32 i = 0; i < o.m_numBits; i++)
573     {
574         m_eq[start + i].copyto(o.m_eq[i]);
575     }
576 }
577 
reverse(UINT_32 start,UINT_32 num)578 VOID CoordEq::reverse(UINT_32 start, UINT_32 num)
579 {
580     UINT_32 n = (num == 0xFFFFFFFF) ? m_numBits : num;
581 
582     for (UINT_32 i = 0; i < n / 2; i++)
583     {
584         CoordTerm temp;
585         m_eq[start + i].copyto(temp);
586         m_eq[start + n - 1 - i].copyto(m_eq[start + i]);
587         temp.copyto(m_eq[start + n - 1 - i]);
588     }
589 }
590 
xorin(CoordEq & x,UINT_32 start)591 VOID CoordEq::xorin(CoordEq& x, UINT_32 start)
592 {
593     UINT_32 n = ((m_numBits - start) < x.m_numBits) ? (m_numBits - start) : x.m_numBits;
594     for (UINT_32 i = 0; i < n; i++)
595     {
596         m_eq[start + i].add(x.m_eq[i]);
597     }
598 }
599 
Filter(INT_8 f,Coordinate & co,UINT_32 start,INT_8 axis)600 UINT_32 CoordEq::Filter(INT_8 f, Coordinate& co, UINT_32 start, INT_8 axis)
601 {
602     for (UINT_32 i = start; i < m_numBits;)
603     {
604         UINT_32 m = m_eq[i].Filter(f, co, 0, axis);
605         if (m == 0)
606         {
607             for (UINT_32 j = i; j < m_numBits - 1; j++)
608             {
609                 m_eq[j] = m_eq[j + 1];
610             }
611             m_numBits--;
612         }
613         else
614         {
615             i++;
616         }
617     }
618     return m_numBits;
619 }
620 
shift(INT_32 amount,INT_32 start)621 VOID CoordEq::shift(INT_32 amount, INT_32 start)
622 {
623     if (amount != 0)
624     {
625         INT_32 numBits = static_cast<INT_32>(m_numBits);
626         amount = -amount;
627         INT_32 inc = (amount < 0) ? -1 : 1;
628         INT_32 i = (amount < 0) ? numBits - 1 : start;
629         INT_32 end = (amount < 0) ? start - 1 : numBits;
630         for (; (inc > 0) ? i < end : i > end; i += inc)
631         {
632             if ((i + amount < start) || (i + amount >= numBits))
633             {
634                 m_eq[i].Clear();
635             }
636             else
637             {
638                 m_eq[i + amount].copyto(m_eq[i]);
639             }
640         }
641     }
642 }
643 
operator [](UINT_32 i)644 CoordTerm& CoordEq::operator[](UINT_32 i)
645 {
646     return m_eq[i];
647 }
648 
mort2d(Coordinate & c0,Coordinate & c1,UINT_32 start,UINT_32 end)649 VOID CoordEq::mort2d(Coordinate& c0, Coordinate& c1, UINT_32 start, UINT_32 end)
650 {
651     if (end == 0)
652     {
653         ADDR_ASSERT(m_numBits > 0);
654         end = m_numBits - 1;
655     }
656     for (UINT_32 i = start; i <= end; i++)
657     {
658         UINT_32 select = (i - start) % 2;
659         Coordinate& c = (select == 0) ? c0 : c1;
660         m_eq[i].add(c);
661         c++;
662     }
663 }
664 
mort3d(Coordinate & c0,Coordinate & c1,Coordinate & c2,UINT_32 start,UINT_32 end)665 VOID CoordEq::mort3d(Coordinate& c0, Coordinate& c1, Coordinate& c2, UINT_32 start, UINT_32 end)
666 {
667     if (end == 0)
668     {
669         ADDR_ASSERT(m_numBits > 0);
670         end = m_numBits - 1;
671     }
672     for (UINT_32 i = start; i <= end; i++)
673     {
674         UINT_32 select = (i - start) % 3;
675         Coordinate& c = (select == 0) ? c0 : ((select == 1) ? c1 : c2);
676         m_eq[i].add(c);
677         c++;
678     }
679 }
680 
operator ==(const CoordEq & b)681 BOOL_32 CoordEq::operator==(const CoordEq& b)
682 {
683     BOOL_32 ret = TRUE;
684 
685     if (m_numBits != b.m_numBits)
686     {
687         ret = FALSE;
688     }
689     else
690     {
691         for (UINT_32 i = 0; i < m_numBits; i++)
692         {
693             if (m_eq[i] != b.m_eq[i])
694             {
695                 ret = FALSE;
696                 break;
697             }
698         }
699     }
700     return ret;
701 }
702 
operator !=(const CoordEq & b)703 BOOL_32 CoordEq::operator!=(const CoordEq& b)
704 {
705     return !(*this == b);
706 }
707 
708