1 //===-- lib/Evaluate/shape.cpp --------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Evaluate/shape.h"
10 #include "flang/Common/idioms.h"
11 #include "flang/Common/template.h"
12 #include "flang/Evaluate/characteristics.h"
13 #include "flang/Evaluate/fold.h"
14 #include "flang/Evaluate/intrinsics.h"
15 #include "flang/Evaluate/tools.h"
16 #include "flang/Evaluate/type.h"
17 #include "flang/Parser/message.h"
18 #include "flang/Semantics/symbol.h"
19 #include <functional>
20 
21 using namespace std::placeholders; // _1, _2, &c. for std::bind()
22 
23 namespace Fortran::evaluate {
24 
IsImpliedShape(const Symbol & symbol0)25 bool IsImpliedShape(const Symbol &symbol0) {
26   const Symbol &symbol{ResolveAssociations(symbol0)};
27   const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()};
28   return symbol.attrs().test(semantics::Attr::PARAMETER) && details &&
29       details->shape().IsImpliedShape();
30 }
31 
IsExplicitShape(const Symbol & symbol0)32 bool IsExplicitShape(const Symbol &symbol0) {
33   const Symbol &symbol{ResolveAssociations(symbol0)};
34   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
35     const auto &shape{details->shape()};
36     return shape.Rank() == 0 || shape.IsExplicitShape(); // even if scalar
37   } else {
38     return false;
39   }
40 }
41 
AsShape(const Constant<ExtentType> & arrayConstant)42 Shape AsShape(const Constant<ExtentType> &arrayConstant) {
43   CHECK(arrayConstant.Rank() == 1);
44   Shape result;
45   std::size_t dimensions{arrayConstant.size()};
46   for (std::size_t j{0}; j < dimensions; ++j) {
47     Scalar<ExtentType> extent{arrayConstant.values().at(j)};
48     result.emplace_back(MaybeExtentExpr{ExtentExpr{extent}});
49   }
50   return result;
51 }
52 
AsShape(FoldingContext & context,ExtentExpr && arrayExpr)53 std::optional<Shape> AsShape(FoldingContext &context, ExtentExpr &&arrayExpr) {
54   // Flatten any array expression into an array constructor if possible.
55   arrayExpr = Fold(context, std::move(arrayExpr));
56   if (const auto *constArray{UnwrapConstantValue<ExtentType>(arrayExpr)}) {
57     return AsShape(*constArray);
58   }
59   if (auto *constructor{UnwrapExpr<ArrayConstructor<ExtentType>>(arrayExpr)}) {
60     Shape result;
61     for (auto &value : *constructor) {
62       if (auto *expr{std::get_if<ExtentExpr>(&value.u)}) {
63         if (expr->Rank() == 0) {
64           result.emplace_back(std::move(*expr));
65           continue;
66         }
67       }
68       return std::nullopt;
69     }
70     return result;
71   }
72   return std::nullopt;
73 }
74 
AsExtentArrayExpr(const Shape & shape)75 std::optional<ExtentExpr> AsExtentArrayExpr(const Shape &shape) {
76   ArrayConstructorValues<ExtentType> values;
77   for (const auto &dim : shape) {
78     if (dim) {
79       values.Push(common::Clone(*dim));
80     } else {
81       return std::nullopt;
82     }
83   }
84   return ExtentExpr{ArrayConstructor<ExtentType>{std::move(values)}};
85 }
86 
AsConstantShape(FoldingContext & context,const Shape & shape)87 std::optional<Constant<ExtentType>> AsConstantShape(
88     FoldingContext &context, const Shape &shape) {
89   if (auto shapeArray{AsExtentArrayExpr(shape)}) {
90     auto folded{Fold(context, std::move(*shapeArray))};
91     if (auto *p{UnwrapConstantValue<ExtentType>(folded)}) {
92       return std::move(*p);
93     }
94   }
95   return std::nullopt;
96 }
97 
AsConstantShape(const ConstantSubscripts & shape)98 Constant<SubscriptInteger> AsConstantShape(const ConstantSubscripts &shape) {
99   using IntType = Scalar<SubscriptInteger>;
100   std::vector<IntType> result;
101   for (auto dim : shape) {
102     result.emplace_back(dim);
103   }
104   return {std::move(result), ConstantSubscripts{GetRank(shape)}};
105 }
106 
AsConstantExtents(const Constant<ExtentType> & shape)107 ConstantSubscripts AsConstantExtents(const Constant<ExtentType> &shape) {
108   ConstantSubscripts result;
109   for (const auto &extent : shape.values()) {
110     result.push_back(extent.ToInt64());
111   }
112   return result;
113 }
114 
AsConstantExtents(FoldingContext & context,const Shape & shape)115 std::optional<ConstantSubscripts> AsConstantExtents(
116     FoldingContext &context, const Shape &shape) {
117   if (auto shapeConstant{AsConstantShape(context, shape)}) {
118     return AsConstantExtents(*shapeConstant);
119   } else {
120     return std::nullopt;
121   }
122 }
123 
ComputeTripCount(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)124 static ExtentExpr ComputeTripCount(FoldingContext &context, ExtentExpr &&lower,
125     ExtentExpr &&upper, ExtentExpr &&stride) {
126   ExtentExpr strideCopy{common::Clone(stride)};
127   ExtentExpr span{
128       (std::move(upper) - std::move(lower) + std::move(strideCopy)) /
129       std::move(stride)};
130   ExtentExpr extent{
131       Extremum<ExtentType>{Ordering::Greater, std::move(span), ExtentExpr{0}}};
132   return Fold(context, std::move(extent));
133 }
134 
CountTrips(FoldingContext & context,ExtentExpr && lower,ExtentExpr && upper,ExtentExpr && stride)135 ExtentExpr CountTrips(FoldingContext &context, ExtentExpr &&lower,
136     ExtentExpr &&upper, ExtentExpr &&stride) {
137   return ComputeTripCount(
138       context, std::move(lower), std::move(upper), std::move(stride));
139 }
140 
CountTrips(FoldingContext & context,const ExtentExpr & lower,const ExtentExpr & upper,const ExtentExpr & stride)141 ExtentExpr CountTrips(FoldingContext &context, const ExtentExpr &lower,
142     const ExtentExpr &upper, const ExtentExpr &stride) {
143   return ComputeTripCount(context, common::Clone(lower), common::Clone(upper),
144       common::Clone(stride));
145 }
146 
CountTrips(FoldingContext & context,MaybeExtentExpr && lower,MaybeExtentExpr && upper,MaybeExtentExpr && stride)147 MaybeExtentExpr CountTrips(FoldingContext &context, MaybeExtentExpr &&lower,
148     MaybeExtentExpr &&upper, MaybeExtentExpr &&stride) {
149   std::function<ExtentExpr(ExtentExpr &&, ExtentExpr &&, ExtentExpr &&)> bound{
150       std::bind(ComputeTripCount, context, _1, _2, _3)};
151   return common::MapOptional(
152       std::move(bound), std::move(lower), std::move(upper), std::move(stride));
153 }
154 
GetSize(Shape && shape)155 MaybeExtentExpr GetSize(Shape &&shape) {
156   ExtentExpr extent{1};
157   for (auto &&dim : std::move(shape)) {
158     if (dim) {
159       extent = std::move(extent) * std::move(*dim);
160     } else {
161       return std::nullopt;
162     }
163   }
164   return extent;
165 }
166 
// Returns true if traversing `expr` with AnyTraverse reaches an
// ImpliedDoIndex node, i.e. the extent expression refers to the index of an
// enclosing implied DO.
bool ContainsAnyImpliedDoIndex(const ExtentExpr &expr) {
  // Visitor whose only interesting case is ImpliedDoIndex; every other node
  // type is handled by the AnyTraverse base.
  struct MyVisitor : public AnyTraverse<MyVisitor> {
    using Base = AnyTraverse<MyVisitor>;
    MyVisitor() : Base{*this} {}
    using Base::operator();
    bool operator()(const ImpliedDoIndex &) { return true; }
  };
  return MyVisitor{}(expr);
}
176 
// Determines the lower bound of a dimension.  This can be other than 1 only
// for a reference to a whole array object or component (see LBOUND,
// 16.9.109).  ASSOCIATE construct entities may require traversal of their
// referents.
// Only the Symbol and Component overloads produce a bound; every other node
// type falls back to the Traverse base, and Combine() discards sub-results in
// favor of the default bound of 1.
class GetLowerBoundHelper : public Traverse<GetLowerBoundHelper, ExtentExpr> {
public:
  using Result = ExtentExpr;
  using Base = Traverse<GetLowerBoundHelper, ExtentExpr>;
  using Base::operator();
  // `d` is the zero-based dimension whose lower bound is wanted.
  GetLowerBoundHelper(FoldingContext &c, int d)
      : Base{*this}, context_{c}, dimension_{d} {}
  static ExtentExpr Default() { return ExtentExpr{1}; }
  static ExtentExpr Combine(Result &&, Result &&) { return Default(); }
  ExtentExpr operator()(const Symbol &);
  ExtentExpr operator()(const Component &);

private:
  FoldingContext &context_;
  int dimension_; // zero-based
};
196 
// Lower bound of dimension `dimension_` of the entity named by `symbol0`:
//  - object entity: the folded explicit lower bound if one is declared;
//    otherwise a DescriptorInquiry when the ultimate symbol is described by
//    a descriptor; otherwise the default of 1.
//  - associate entity: the result of traversing the associated expression.
//  - anything else: 1.
auto GetLowerBoundHelper::operator()(const Symbol &symbol0) -> Result {
  const Symbol &symbol{symbol0.GetUltimate()};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension_) {
        if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
          return Fold(context_, common::Clone(*bound));
        } else if (IsDescriptor(symbol)) {
          // The inquiry is made through the name as referenced (symbol0),
          // not through its ultimate symbol.
          return ExtentExpr{DescriptorInquiry{NamedEntity{symbol0},
              DescriptorInquiry::Field::LowerBound, dimension_}};
        } else {
          break;
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    return (*this)(assoc->expr());
  }
  return Default();
}
219 
// Lower bound of dimension `dimension_` of a component reference.  A
// non-default bound is produced only when the component's base is scalar
// (component.base().Rank() == 0).  The bound is then the component's folded
// explicit lower bound, or a DescriptorInquiry on the component when it is
// described by a descriptor.  Every other case yields 1.
auto GetLowerBoundHelper::operator()(const Component &component) -> Result {
  if (component.base().Rank() == 0) {
    const Symbol &symbol{component.GetLastSymbol().GetUltimate()};
    if (const auto *details{
            symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
      int j{0};
      for (const auto &shapeSpec : details->shape()) {
        if (j++ == dimension_) {
          if (const auto &bound{shapeSpec.lbound().GetExplicit()}) {
            return Fold(context_, common::Clone(*bound));
          } else if (IsDescriptor(symbol)) {
            return ExtentExpr{
                DescriptorInquiry{NamedEntity{common::Clone(component)},
                    DescriptorInquiry::Field::LowerBound, dimension_}};
          } else {
            break;
          }
        }
      }
    }
  }
  return Default();
}
243 
GetLowerBound(FoldingContext & context,const NamedEntity & base,int dimension)244 ExtentExpr GetLowerBound(
245     FoldingContext &context, const NamedEntity &base, int dimension) {
246   return GetLowerBoundHelper{context, dimension}(base);
247 }
248 
GetLowerBounds(FoldingContext & context,const NamedEntity & base)249 Shape GetLowerBounds(FoldingContext &context, const NamedEntity &base) {
250   Shape result;
251   int rank{base.Rank()};
252   for (int dim{0}; dim < rank; ++dim) {
253     result.emplace_back(GetLowerBound(context, base, dim));
254   }
255   return result;
256 }
257 
GetExtent(FoldingContext & context,const NamedEntity & base,int dimension)258 MaybeExtentExpr GetExtent(
259     FoldingContext &context, const NamedEntity &base, int dimension) {
260   CHECK(dimension >= 0);
261   const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
262   if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
263     if (IsImpliedShape(symbol)) {
264       Shape shape{GetShape(context, symbol).value()};
265       return std::move(shape.at(dimension));
266     }
267     int j{0};
268     for (const auto &shapeSpec : details->shape()) {
269       if (j++ == dimension) {
270         if (shapeSpec.ubound().isExplicit()) {
271           if (const auto &ubound{shapeSpec.ubound().GetExplicit()}) {
272             if (const auto &lbound{shapeSpec.lbound().GetExplicit()}) {
273               return Fold(context,
274                   common::Clone(ubound.value()) -
275                       common::Clone(lbound.value()) + ExtentExpr{1});
276             } else {
277               return Fold(context, common::Clone(ubound.value()));
278             }
279           }
280         } else if (details->IsAssumedSize() && j == symbol.Rank()) {
281           return std::nullopt;
282         } else if (semantics::IsDescriptor(symbol)) {
283           return ExtentExpr{DescriptorInquiry{
284               NamedEntity{base}, DescriptorInquiry::Field::Extent, dimension}};
285         }
286       }
287     }
288   } else if (const auto *assoc{
289                  symbol.detailsIf<semantics::AssocEntityDetails>()}) {
290     if (auto shape{GetShape(context, assoc->expr())}) {
291       if (dimension < static_cast<int>(shape->size())) {
292         return std::move(shape->at(dimension));
293       }
294     }
295   }
296   return std::nullopt;
297 }
298 
GetExtent(FoldingContext & context,const Subscript & subscript,const NamedEntity & base,int dimension)299 MaybeExtentExpr GetExtent(FoldingContext &context, const Subscript &subscript,
300     const NamedEntity &base, int dimension) {
301   return std::visit(
302       common::visitors{
303           [&](const Triplet &triplet) -> MaybeExtentExpr {
304             MaybeExtentExpr upper{triplet.upper()};
305             if (!upper) {
306               upper = GetUpperBound(context, base, dimension);
307             }
308             MaybeExtentExpr lower{triplet.lower()};
309             if (!lower) {
310               lower = GetLowerBound(context, base, dimension);
311             }
312             return CountTrips(context, std::move(lower), std::move(upper),
313                 MaybeExtentExpr{triplet.stride()});
314           },
315           [&](const IndirectSubscriptIntegerExpr &subs) -> MaybeExtentExpr {
316             if (auto shape{GetShape(context, subs.value())}) {
317               if (GetRank(*shape) > 0) {
318                 CHECK(GetRank(*shape) == 1); // vector-valued subscript
319                 return std::move(shape->at(0));
320               }
321             }
322             return std::nullopt;
323           },
324       },
325       subscript.u);
326 }
327 
ComputeUpperBound(FoldingContext & context,ExtentExpr && lower,MaybeExtentExpr && extent)328 MaybeExtentExpr ComputeUpperBound(
329     FoldingContext &context, ExtentExpr &&lower, MaybeExtentExpr &&extent) {
330   if (extent) {
331     return Fold(context, std::move(*extent) - std::move(lower) + ExtentExpr{1});
332   } else {
333     return std::nullopt;
334   }
335 }
336 
// Upper bound of zero-based dimension `dimension` of `base`:
//  - object entity: the folded explicit upper bound when one is declared;
//    std::nullopt for the last dimension of an assumed-size array;
//    otherwise computed from the lower bound and extent (ComputeUpperBound).
//  - associate entity: computed from the lower bound and the extent taken
//    from the associated expression's shape.
//  - anything else, or an out-of-range dimension: std::nullopt.
MaybeExtentExpr GetUpperBound(
    FoldingContext &context, const NamedEntity &base, int dimension) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    int j{0};
    for (const auto &shapeSpec : details->shape()) {
      if (j++ == dimension) {
        if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
          return Fold(context, common::Clone(*bound));
        } else if (details->IsAssumedSize() && dimension + 1 == symbol.Rank()) {
          // Assumed-size final dimension: upper bound is unknown.
          break;
        } else {
          return ComputeUpperBound(context,
              GetLowerBound(context, base, dimension),
              GetExtent(context, base, dimension));
        }
      }
    }
  } else if (const auto *assoc{
                 symbol.detailsIf<semantics::AssocEntityDetails>()}) {
    if (auto shape{GetShape(context, assoc->expr())}) {
      if (dimension < static_cast<int>(shape->size())) {
        return ComputeUpperBound(context,
            GetLowerBound(context, base, dimension),
            std::move(shape->at(dimension)));
      }
    }
  }
  return std::nullopt;
}
367 
// Upper bounds of every dimension of `base`.  For an object entity, each
// dimension gets one of the following:
//  - its folded explicit upper bound, if declared;
//  - std::nullopt for the final dimension of an assumed-size array;
//  - a value computed from the lower bound and extent.
// For any other entity, the extents reported by GetShape are returned as-is.
// NOTE(review): that fallback treats extents as upper bounds, which presumes
// lower bounds of 1 for non-object entities (e.g. ASSOCIATE names).  Also,
// GetShape(...).value() throws std::bad_optional_access if the shape is
// unknown.  Confirm callers guarantee both.
Shape GetUpperBounds(FoldingContext &context, const NamedEntity &base) {
  const Symbol &symbol{ResolveAssociations(base.GetLastSymbol())};
  if (const auto *details{symbol.detailsIf<semantics::ObjectEntityDetails>()}) {
    Shape result;
    int dim{0};
    for (const auto &shapeSpec : details->shape()) {
      if (const auto &bound{shapeSpec.ubound().GetExplicit()}) {
        result.emplace_back(Fold(context, common::Clone(*bound)));
      } else if (details->IsAssumedSize()) {
        // Only the last dimension of an assumed-size array lacks a bound.
        CHECK(dim + 1 == base.Rank());
        result.emplace_back(std::nullopt); // UBOUND folding replaces with -1
      } else {
        result.emplace_back(ComputeUpperBound(context,
            GetLowerBound(context, base, dim), GetExtent(context, base, dim)));
      }
      ++dim;
    }
    CHECK(GetRank(result) == symbol.Rank());
    return result;
  } else {
    return std::move(GetShape(context, base).value());
  }
}
391 
// Shape of the entity named by `symbol`, chosen by the kind of details the
// symbol carries.  Each lambda below handles one kind; any kind not listed
// yields `Result{}` (no shape information).
auto GetShapeHelper::operator()(const Symbol &symbol) const -> Result {
  return std::visit(
      common::visitors{
          // Object: an implied-shape named constant takes the shape of its
          // initializer; otherwise a shape of the declared rank is built by
          // CreateShape.
          [&](const semantics::ObjectEntityDetails &object) {
            if (IsImpliedShape(symbol)) {
              return (*this)(object.init());
            } else {
              int n{object.shape().Rank()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          [](const semantics::EntityDetails &) {
            return Scalar(); // no dimensions seen
          },
          // Procedure entity: shape of its interface, if it has one.
          [&](const semantics::ProcEntityDetails &proc) {
            if (const Symbol * interface{proc.interface().symbol()}) {
              return (*this)(*interface);
            } else {
              return Scalar();
            }
          },
          // Associate entity: without a recorded rank, use the associated
          // expression's shape; otherwise build one of that rank.
          [&](const semantics::AssocEntityDetails &assoc) {
            if (!assoc.rank()) {
              return (*this)(assoc.expr());
            } else {
              int n{assoc.rank().value()};
              NamedEntity base{symbol};
              return Result{CreateShape(n, base)};
            }
          },
          // Function: shape of its result; subroutine: no shape.
          [&](const semantics::SubprogramDetails &subp) {
            if (subp.isFunction()) {
              return (*this)(subp.result());
            } else {
              return Result{};
            }
          },
          // Bindings and use/host associations forward to their target.
          [&](const semantics::ProcBindingDetails &binding) {
            return (*this)(binding.symbol());
          },
          [&](const semantics::UseDetails &use) {
            return (*this)(use.symbol());
          },
          [&](const semantics::HostAssocDetails &assoc) {
            return (*this)(assoc.symbol());
          },
          [](const semantics::TypeParamDetails &) { return Scalar(); },
          [](const auto &) { return Result{}; },
      },
      symbol.details());
}
444 
operator ()(const Component & component) const445 auto GetShapeHelper::operator()(const Component &component) const -> Result {
446   const Symbol &symbol{component.GetLastSymbol()};
447   int rank{symbol.Rank()};
448   if (rank == 0) {
449     return (*this)(component.base());
450   } else if (symbol.has<semantics::ObjectEntityDetails>()) {
451     NamedEntity base{Component{component}};
452     return CreateShape(rank, base);
453   } else if (symbol.has<semantics::AssocEntityDetails>()) {
454     NamedEntity base{Component{component}};
455     return Result{CreateShape(rank, base)};
456   } else {
457     return (*this)(symbol);
458   }
459 }
460 
operator ()(const ArrayRef & arrayRef) const461 auto GetShapeHelper::operator()(const ArrayRef &arrayRef) const -> Result {
462   Shape shape;
463   int dimension{0};
464   const NamedEntity &base{arrayRef.base()};
465   for (const Subscript &ss : arrayRef.subscript()) {
466     if (ss.Rank() > 0) {
467       shape.emplace_back(GetExtent(context_, ss, base, dimension));
468     }
469     ++dimension;
470   }
471   if (shape.empty()) {
472     if (const Component * component{base.UnwrapComponent()}) {
473       return (*this)(component->base());
474     }
475   }
476   return shape;
477 }
478 
operator ()(const CoarrayRef & coarrayRef) const479 auto GetShapeHelper::operator()(const CoarrayRef &coarrayRef) const -> Result {
480   NamedEntity base{coarrayRef.GetBase()};
481   if (coarrayRef.subscript().empty()) {
482     return (*this)(base);
483   } else {
484     Shape shape;
485     int dimension{0};
486     for (const Subscript &ss : coarrayRef.subscript()) {
487       if (ss.Rank() > 0) {
488         shape.emplace_back(GetExtent(context_, ss, base, dimension));
489       }
490       ++dimension;
491     }
492     return shape;
493   }
494 }
495 
operator ()(const Substring & substring) const496 auto GetShapeHelper::operator()(const Substring &substring) const -> Result {
497   return (*this)(substring.parent());
498 }
499 
operator ()(const ProcedureRef & call) const500 auto GetShapeHelper::operator()(const ProcedureRef &call) const -> Result {
501   if (call.Rank() == 0) {
502     return Scalar();
503   } else if (call.IsElemental()) {
504     for (const auto &arg : call.arguments()) {
505       if (arg && arg->Rank() > 0) {
506         return (*this)(*arg);
507       }
508     }
509     return Scalar();
510   } else if (const Symbol * symbol{call.proc().GetSymbol()}) {
511     return (*this)(*symbol);
512   } else if (const auto *intrinsic{call.proc().GetSpecificIntrinsic()}) {
513     if (intrinsic->name == "shape" || intrinsic->name == "lbound" ||
514         intrinsic->name == "ubound") {
515       // These are the array-valued cases for LBOUND and UBOUND (no DIM=).
516       const auto *expr{call.arguments().front().value().UnwrapExpr()};
517       CHECK(expr);
518       return Shape{MaybeExtentExpr{ExtentExpr{expr->Rank()}}};
519     } else if (intrinsic->name == "all" || intrinsic->name == "any" ||
520         intrinsic->name == "count" || intrinsic->name == "iall" ||
521         intrinsic->name == "iany" || intrinsic->name == "iparity" ||
522         intrinsic->name == "maxloc" || intrinsic->name == "maxval" ||
523         intrinsic->name == "minloc" || intrinsic->name == "minval" ||
524         intrinsic->name == "norm2" || intrinsic->name == "parity" ||
525         intrinsic->name == "product" || intrinsic->name == "sum") {
526       // Reduction with DIM=
527       if (call.arguments().size() >= 2) {
528         auto arrayShape{
529             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
530         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
531         if (arrayShape && dimArg) {
532           if (auto dim{ToInt64(*dimArg)}) {
533             if (*dim >= 1 &&
534                 static_cast<std::size_t>(*dim) <= arrayShape->size()) {
535               arrayShape->erase(arrayShape->begin() + (*dim - 1));
536               return std::move(*arrayShape);
537             }
538           }
539         }
540       }
541     } else if (intrinsic->name == "cshift" || intrinsic->name == "eoshift") {
542       if (!call.arguments().empty()) {
543         return (*this)(call.arguments()[0]);
544       }
545     } else if (intrinsic->name == "matmul") {
546       if (call.arguments().size() == 2) {
547         if (auto ashape{(*this)(call.arguments()[0])}) {
548           if (auto bshape{(*this)(call.arguments()[1])}) {
549             if (ashape->size() == 1 && bshape->size() == 2) {
550               bshape->erase(bshape->begin());
551               return std::move(*bshape); // matmul(vector, matrix)
552             } else if (ashape->size() == 2 && bshape->size() == 1) {
553               ashape->pop_back();
554               return std::move(*ashape); // matmul(matrix, vector)
555             } else if (ashape->size() == 2 && bshape->size() == 2) {
556               (*ashape)[1] = std::move((*bshape)[1]);
557               return std::move(*ashape); // matmul(matrix, matrix)
558             }
559           }
560         }
561       }
562     } else if (intrinsic->name == "reshape") {
563       if (call.arguments().size() >= 2 && call.arguments().at(1)) {
564         // SHAPE(RESHAPE(array,shape)) -> shape
565         if (const auto *shapeExpr{
566                 call.arguments().at(1).value().UnwrapExpr()}) {
567           auto shape{std::get<Expr<SomeInteger>>(shapeExpr->u)};
568           return AsShape(context_, ConvertToType<ExtentType>(std::move(shape)));
569         }
570       }
571     } else if (intrinsic->name == "pack") {
572       if (call.arguments().size() >= 3 && call.arguments().at(2)) {
573         // SHAPE(PACK(,,VECTOR=v)) -> SHAPE(v)
574         return (*this)(call.arguments().at(2));
575       } else if (call.arguments().size() >= 2) {
576         if (auto maskShape{(*this)(call.arguments().at(1))}) {
577           if (maskShape->size() == 0) {
578             // Scalar MASK= -> [MERGE(SIZE(ARRAY=), 0, mask)]
579             if (auto arrayShape{(*this)(call.arguments().at(0))}) {
580               auto arraySize{GetSize(std::move(*arrayShape))};
581               CHECK(arraySize);
582               ActualArguments toMerge{
583                   ActualArgument{AsGenericExpr(std::move(*arraySize))},
584                   ActualArgument{AsGenericExpr(ExtentExpr{0})},
585                   common::Clone(call.arguments().at(1))};
586               auto specific{context_.intrinsics().Probe(
587                   CallCharacteristics{"merge"}, toMerge, context_)};
588               CHECK(specific);
589               return Shape{ExtentExpr{FunctionRef<ExtentType>{
590                   ProcedureDesignator{std::move(specific->specificIntrinsic)},
591                   std::move(specific->arguments)}}};
592             }
593           } else {
594             // Non-scalar MASK= -> [COUNT(mask)]
595             ActualArguments toCount{ActualArgument{common::Clone(
596                 DEREF(call.arguments().at(1).value().UnwrapExpr()))}};
597             auto specific{context_.intrinsics().Probe(
598                 CallCharacteristics{"count"}, toCount, context_)};
599             CHECK(specific);
600             return Shape{ExtentExpr{FunctionRef<ExtentType>{
601                 ProcedureDesignator{std::move(specific->specificIntrinsic)},
602                 std::move(specific->arguments)}}};
603           }
604         }
605       }
606     } else if (intrinsic->name == "spread") {
607       // SHAPE(SPREAD(ARRAY,DIM,NCOPIES)) = SHAPE(ARRAY) with NCOPIES inserted
608       // at position DIM.
609       if (call.arguments().size() == 3) {
610         auto arrayShape{
611             (*this)(UnwrapExpr<Expr<SomeType>>(call.arguments().at(0)))};
612         const auto *dimArg{UnwrapExpr<Expr<SomeType>>(call.arguments().at(1))};
613         const auto *nCopies{
614             UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))};
615         if (arrayShape && dimArg && nCopies) {
616           if (auto dim{ToInt64(*dimArg)}) {
617             if (*dim >= 1 &&
618                 static_cast<std::size_t>(*dim) <= arrayShape->size() + 1) {
619               arrayShape->emplace(arrayShape->begin() + *dim - 1,
620                   ConvertToType<ExtentType>(common::Clone(*nCopies)));
621               return std::move(*arrayShape);
622             }
623           }
624         }
625       }
626     } else if (intrinsic->name == "transfer") {
627       if (call.arguments().size() == 3 && call.arguments().at(2)) {
628         // SIZE= is present; shape is vector [SIZE=]
629         if (const auto *size{
630                 UnwrapExpr<Expr<SomeInteger>>(call.arguments().at(2))}) {
631           return Shape{
632               MaybeExtentExpr{ConvertToType<ExtentType>(common::Clone(*size))}};
633         }
634       } else if (auto moldTypeAndShape{
635                      characteristics::TypeAndShape::Characterize(
636                          call.arguments().at(1), context_)}) {
637         if (GetRank(moldTypeAndShape->shape()) == 0) {
638           // SIZE= is absent and MOLD= is scalar: result is scalar
639           return Scalar();
640         } else {
641           // SIZE= is absent and MOLD= is array: result is vector whose
642           // length is determined by sizes of types.  See 16.9.193p4 case(ii).
643           if (auto sourceTypeAndShape{
644                   characteristics::TypeAndShape::Characterize(
645                       call.arguments().at(0), context_)}) {
646             auto sourceElements{
647                 GetSize(common::Clone(sourceTypeAndShape->shape()))};
648             auto sourceElementBytes{
649                 sourceTypeAndShape->MeasureSizeInBytes(&context_)};
650             auto moldElementBytes{
651                 moldTypeAndShape->MeasureSizeInBytes(&context_)};
652             if (sourceElements && sourceElementBytes && moldElementBytes) {
653               ExtentExpr extent{Fold(context_,
654                   ((std::move(*sourceElements) *
655                        std::move(*sourceElementBytes)) +
656                       common::Clone(*moldElementBytes) - ExtentExpr{1}) /
657                       common::Clone(*moldElementBytes))};
658               return Shape{MaybeExtentExpr{std::move(extent)}};
659             }
660           }
661         }
662       }
663     } else if (intrinsic->name == "transpose") {
664       if (call.arguments().size() >= 1) {
665         if (auto shape{(*this)(call.arguments().at(0))}) {
666           if (shape->size() == 2) {
667             std::swap((*shape)[0], (*shape)[1]);
668             return shape;
669           }
670         }
671       }
672     } else if (intrinsic->characteristics.value().attrs.test(characteristics::
673                        Procedure::Attr::NullPointer)) { // NULL(MOLD=)
674       return (*this)(call.arguments());
675     } else {
676       // TODO: shapes of other non-elemental intrinsic results
677     }
678   }
679   return std::nullopt;
680 }
681 
682 // Check conformance of the passed shapes.  Only return true if we can verify
683 // that they conform
CheckConformance(parser::ContextualMessages & messages,const Shape & left,const Shape & right,const char * leftIs,const char * rightIs,bool leftScalarExpandable,bool rightScalarExpandable)684 bool CheckConformance(parser::ContextualMessages &messages, const Shape &left,
685     const Shape &right, const char *leftIs, const char *rightIs,
686     bool leftScalarExpandable, bool rightScalarExpandable) {
687   int n{GetRank(left)};
688   if (n == 0 && leftScalarExpandable) {
689     return true;
690   }
691   int rn{GetRank(right)};
692   if (rn == 0 && rightScalarExpandable) {
693     return true;
694   }
695   if (n != rn) {
696     messages.Say("Rank of %1$s is %2$d, but %3$s has rank %4$d"_err_en_US,
697         leftIs, n, rightIs, rn);
698     return false;
699   }
700   for (int j{0}; j < n; ++j) {
701     auto leftDim{ToInt64(left[j])};
702     auto rightDim{ToInt64(right[j])};
703     if (!leftDim || !rightDim) {
704       return false;
705     }
706     if (*leftDim != *rightDim) {
707       messages.Say("Dimension %1$d of %2$s has extent %3$jd, "
708                    "but %4$s has extent %5$jd"_err_en_US,
709           j + 1, leftIs, *leftDim, rightIs, *rightDim);
710       return false;
711     }
712   }
713   return true;
714 }
715 
IncrementSubscripts(ConstantSubscripts & indices,const ConstantSubscripts & extents)716 bool IncrementSubscripts(
717     ConstantSubscripts &indices, const ConstantSubscripts &extents) {
718   std::size_t rank(indices.size());
719   CHECK(rank <= extents.size());
720   for (std::size_t j{0}; j < rank; ++j) {
721     if (extents[j] < 1) {
722       return false;
723     }
724   }
725   for (std::size_t j{0}; j < rank; ++j) {
726     if (indices[j]++ < extents[j]) {
727       return true;
728     }
729     indices[j] = 1;
730   }
731   return false;
732 }
733 } // namespace Fortran::evaluate
734