1 /*
<lambda>null2  * Copyright 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.compose.animation.scene
18 
19 import androidx.compose.runtime.Stable
20 import androidx.compose.runtime.getValue
21 import androidx.compose.runtime.mutableStateOf
22 import androidx.compose.runtime.setValue
23 import androidx.compose.runtime.snapshots.SnapshotStateMap
24 import androidx.compose.ui.ExperimentalComposeUiApi
25 import androidx.compose.ui.Modifier
26 import androidx.compose.ui.geometry.Offset
27 import androidx.compose.ui.geometry.isSpecified
28 import androidx.compose.ui.geometry.isUnspecified
29 import androidx.compose.ui.geometry.lerp
30 import androidx.compose.ui.graphics.CompositingStrategy
31 import androidx.compose.ui.graphics.drawscope.ContentDrawScope
32 import androidx.compose.ui.graphics.drawscope.scale
33 import androidx.compose.ui.layout.ApproachLayoutModifierNode
34 import androidx.compose.ui.layout.ApproachMeasureScope
35 import androidx.compose.ui.layout.LayoutCoordinates
36 import androidx.compose.ui.layout.Measurable
37 import androidx.compose.ui.layout.MeasureResult
38 import androidx.compose.ui.layout.MeasureScope
39 import androidx.compose.ui.layout.Placeable
40 import androidx.compose.ui.node.DrawModifierNode
41 import androidx.compose.ui.node.ModifierNodeElement
42 import androidx.compose.ui.node.TraversableNode
43 import androidx.compose.ui.node.traverseDescendants
44 import androidx.compose.ui.platform.testTag
45 import androidx.compose.ui.unit.Constraints
46 import androidx.compose.ui.unit.IntSize
47 import androidx.compose.ui.unit.round
48 import androidx.compose.ui.util.fastCoerceIn
49 import androidx.compose.ui.util.fastLastOrNull
50 import androidx.compose.ui.util.lerp
51 import com.android.compose.animation.scene.transformation.PropertyTransformation
52 import com.android.compose.animation.scene.transformation.SharedElementTransformation
53 import com.android.compose.ui.util.lerp
54 import kotlin.math.roundToInt
55 import kotlinx.coroutines.launch
56 
/**
 * An element on screen, identified by its [key], that can be composed in one or more scenes.
 *
 * It holds one [SceneState] per scene in which it currently has a state, plus bookkeeping that is
 * shared by all of those scenes.
 */
@Stable
internal class Element(val key: ElementKey) {
    /**
     * The state of this element in each scene where it has one, keyed by scene. A scene that is
     * missing from this map has no state for this element.
     */
    // TODO(b/316901148): Make this a normal map instead once we can make sure that new transitions
    // are first seen by composition then layout/drawing code. See b/316901148#comment2 for details.
    val sceneStates: SnapshotStateMap<SceneKey, SceneState> = SnapshotStateMap()

    /**
     * The transition that was last used to compute the state (size, position and alpha) of this
     * element in any of its scenes, or `null` if the element was last laid out while idle.
     */
    var lastTransition: TransitionState.Transition? = null

    /** Becomes `true` once this element has been drawn in at least one scene. */
    var wasDrawnInAnyScene: Boolean = false

    override fun toString(): String = "Element(key=$key)"

    /**
     * The state of this element in a single [scene]. It contains its target values, the values it
     * was last laid out/drawn with, and the bookkeeping needed to smooth interruptions.
     */
    @Stable
    class SceneState(val scene: SceneKey) {
        /**
         * The *target* size and offset of this element in [scene], i.e. the values it has when the
         * layout is idle on [scene].
         */
        var targetSize: IntSize by mutableStateOf(SizeUnspecified)
        var targetOffset: Offset by mutableStateOf(Offset.Unspecified)

        /** The values this element was last laid out or drawn with in [scene]. */
        var lastOffset: Offset = Offset.Unspecified
        var lastSize: IntSize = SizeUnspecified
        var lastScale: Scale = Scale.Unspecified
        var lastAlpha: Float = AlphaUnspecified

        /**
         * The values this element had in [scene] right before the last interruption, if any.
         * Unspecified values mean that there is nothing pending.
         */
        var offsetBeforeInterruption: Offset = Offset.Unspecified
        var sizeBeforeInterruption: IntSize = SizeUnspecified
        var scaleBeforeInterruption: Scale = Scale.Unspecified
        var alphaBeforeInterruption: Float = AlphaUnspecified

        /**
         * Deltas that are added to the state of this element so that interruptions look smoother.
         * Each delta is meant to be multiplied by the
         * [current interruption progress][TransitionState.Transition.interruptionProgress], so that
         * it animates from its full value down to 0.
         */
        var offsetInterruptionDelta: Offset = Offset.Zero
        var sizeInterruptionDelta: IntSize = IntSize.Zero
        var scaleInterruptionDelta: Scale = Scale.Zero
        var alphaInterruptionDelta: Float = 0f

        /**
         * The [ElementNode]s attached by Modifier.element() for this element in [scene]. During
         * composition this set can briefly hold 0 to 2 nodes. Once all modifier nodes have been
         * attached/detached, it is expected to hold exactly 1 node.
         */
        val nodes: MutableSet<ElementNode> = mutableSetOf()
    }

    companion object {
        /** Sentinel size meaning that no size is specified. */
        val SizeUnspecified: IntSize = IntSize(Int.MAX_VALUE, Int.MAX_VALUE)

        /** Sentinel alpha meaning that no alpha is specified. */
        val AlphaUnspecified: Float = Float.MAX_VALUE
    }
}
124 
/**
 * A 2D scale applied to an element.
 *
 * @param scaleX the horizontal scale factor.
 * @param scaleY the vertical scale factor.
 * @param pivot the point around which the scale is applied, or [Offset.Unspecified] if no pivot
 *   was specified.
 */
data class Scale(
    val scaleX: Float,
    val scaleY: Float,
    val pivot: Offset = Offset.Unspecified,
) {
    companion object {
        /** The identity scale (1x on both axes) with no specific pivot. */
        val Default = Scale(scaleX = 1f, scaleY = 1f, pivot = Offset.Unspecified)

        /** A scale of 0 on both axes, pivoting around [Offset.Zero]. */
        val Zero = Scale(scaleX = 0f, scaleY = 0f, pivot = Offset.Zero)

        /** Sentinel scale meaning that no scale is specified. */
        val Unspecified =
            Scale(scaleX = Float.MAX_VALUE, scaleY = Float.MAX_VALUE, pivot = Offset.Unspecified)
    }
}
132 
/**
 * The implementation of [SceneScope.element]. It attaches an [ElementNode] for [key] in [scene] and
 * tags the resulting modifier with the test tag of [key].
 */
@Stable
internal fun Modifier.element(
    layoutImpl: SceneTransitionLayoutImpl,
    scene: Scene,
    key: ElementKey,
): Modifier {
    // The current transitions are deliberately read here, during composition, instead of being
    // read by ElementNode during layout/drawing.
    // TODO(b/341072461): Revert this and read the current transitions in ElementNode directly once
    // we can ensure that SceneTransitionLayoutImpl will compose new scenes first.
    val elementModifier =
        ElementModifier(
            layoutImpl = layoutImpl,
            currentTransitions = layoutImpl.state.currentTransitions,
            scene = scene,
            key = key,
        )
    return this.then(elementModifier).testTag(key.testTag)
}
147 
/**
 * The [ModifierNodeElement] that creates and updates [ElementNode]s.
 *
 * Updates are only partially supported. [ElementNode.update] checks that [layoutImpl] and [scene]
 * did not change, so only [currentTransitions] and [key] may differ between updates.
 */
private data class ElementModifier(
    private val layoutImpl: SceneTransitionLayoutImpl,
    private val currentTransitions: List<TransitionState.Transition>,
    private val scene: Scene,
    private val key: ElementKey,
) : ModifierNodeElement<ElementNode>() {
    override fun create(): ElementNode {
        return ElementNode(layoutImpl, currentTransitions, scene, key)
    }

    override fun update(node: ElementNode) =
        node.update(layoutImpl, currentTransitions, scene, key)
}
164 
/**
 * The [Modifier.Node] created by [ElementModifier] for the element [key] composed in [scene].
 *
 * While attached, this node registers itself in [Element.SceneState.nodes] for its (key, scene)
 * pair. During the lookahead pass ([measure]) it records the target size and offset of the element
 * in its scene. During the approach pass ([approachMeasure]) and drawing ([draw]) it computes the
 * values to use for the current transition, taking interruptions into account.
 */
internal class ElementNode(
    private var layoutImpl: SceneTransitionLayoutImpl,
    private var currentTransitions: List<TransitionState.Transition>,
    private var scene: Scene,
    private var key: ElementKey,
) : Modifier.Node(), DrawModifierNode, ApproachLayoutModifierNode, TraversableNode {
    // Set by updateElementAndSceneValues() (onAttach/update) and reset to null in onDetach().
    private var _element: Element? = null
    // NOTE(review): the `!!` assumes this is only read while the node is attached; reading it
    // after onDetach() throws a NullPointerException.
    private val element: Element
        get() = _element!!

    // Set by updateElementAndSceneValues() (onAttach/update) and reset to null in onDetach().
    private var _sceneState: Element.SceneState? = null
    // Same attached-only assumption as [element].
    private val sceneState: Element.SceneState
        get() = _sceneState!!

    // Shared by all ElementNodes so that recursivelyClearPlacementValues() can find descendant
    // ElementNodes with traverseDescendants().
    override val traverseKey: Any = ElementTraverseKey

    override fun onAttach() {
        super.onAttach()
        updateElementAndSceneValues()
        addNodeToSceneState()
    }

    /**
     * Resolves [_element] and [_sceneState] for the current [key] and [scene]. If they do not exist
     * yet, a new [Element] is created and stored in `layoutImpl.elements` and/or a new
     * [Element.SceneState] is created and stored in [Element.sceneStates].
     */
    private fun updateElementAndSceneValues() {
        val element =
            layoutImpl.elements[key] ?: Element(key).also { layoutImpl.elements[key] = it }
        _element = element
        _sceneState =
            element.sceneStates[scene.key]
                ?: Element.SceneState(scene.key).also { element.sceneStates[scene.key] = it }
    }

    /**
     * Registers this node in [sceneState] and launches, in this node's coroutineScope, a check that
     * this node ends up being the only node registered for this element in this scene.
     */
    private fun addNodeToSceneState() {
        sceneState.nodes.add(this)

        coroutineScope.launch {
            // At this point all [ElementNode]s have been attached or detached, which means that
            // [Element.SceneState.nodes] should contain exactly this node. Otherwise this element
            // was composed multiple times in the same scene.
            // Note: despite its name, nCodeLocations is the number of registered ElementNodes.
            val nCodeLocations = sceneState.nodes.size
            if (nCodeLocations != 1 || !sceneState.nodes.contains(this@ElementNode)) {
                error("$key was composed $nCodeLocations times in ${sceneState.scene}")
            }
        }
    }

    override fun onDetach() {
        super.onDetach()
        removeNodeFromSceneState()
        // Remove the scene state (and possibly the element) from the maps if no node uses them.
        maybePruneMaps(layoutImpl, element, sceneState)

        _element = null
        _sceneState = null
    }

    /** Unregisters this node from [sceneState]. */
    private fun removeNodeFromSceneState() {
        sceneState.nodes.remove(this)
    }

    /**
     * Updates this node with new arguments coming from [ElementModifier].
     *
     * [layoutImpl] and [scene] must not change; this is enforced with [check], which throws an
     * IllegalStateException. This node is always unregistered from its current scene state and
     * re-registered in the scene state of [key], which may be the same one. The previous element
     * and scene state are then pruned from the maps if they are no longer used.
     */
    fun update(
        layoutImpl: SceneTransitionLayoutImpl,
        currentTransitions: List<TransitionState.Transition>,
        scene: Scene,
        key: ElementKey,
    ) {
        check(layoutImpl == this.layoutImpl && scene == this.scene)
        this.currentTransitions = currentTransitions

        removeNodeFromSceneState()

        // Capture the previous state before updateElementAndSceneValues() replaces it, so that it
        // can be pruned once this node is registered in the new one.
        val prevElement = this.element
        val prevSceneState = this.sceneState
        this.key = key
        updateElementAndSceneValues()

        addNodeToSceneState()
        maybePruneMaps(layoutImpl, prevElement, prevSceneState)
    }

    /** The measurement approach is considered in progress whenever the layout is transitioning. */
    override fun isMeasurementApproachInProgress(lookaheadSize: IntSize): Boolean {
        // TODO(b/324191441): Investigate whether making this check more complex (checking if this
        // element is shared or transformed) would lead to better performance.
        return layoutImpl.state.isTransitioning()
    }

    /** The placement approach is considered in progress whenever the layout is transitioning. */
    override fun Placeable.PlacementScope.isPlacementApproachInProgress(
        lookaheadCoordinates: LayoutCoordinates
    ): Boolean {
        // TODO(b/324191441): Investigate whether making this check more complex (checking if this
        // element is shared or transformed) would lead to better performance.
        return layoutImpl.state.isTransitioning()
    }

    /**
     * Lookahead measurement. It must only run while looking ahead (enforced with [check]). It
     * records the *target* size and offset of this element in its scene, then places the content at
     * (0, 0).
     */
    @ExperimentalComposeUiApi
    override fun MeasureScope.measure(
        measurable: Measurable,
        constraints: Constraints
    ): MeasureResult {
        check(isLookingAhead)

        return measurable.measure(constraints).run {
            // Update the size this element has in this scene when idle.
            sceneState.targetSize = size()

            layout(width, height) {
                // Update the offset (relative to the SceneTransitionLayout) this element has in
                // this scene when idle.
                coordinates?.let { coords ->
                    with(layoutImpl.lookaheadScope) {
                        sceneState.targetOffset =
                            lookaheadScopeCoordinates.localLookaheadPositionOf(coords)
                    }
                }
                place(0, 0)
            }
        }
    }

    /**
     * Approach measurement.
     *
     * If this element should not be laid out in this scene right now, it is measured normally but
     * not placed. Its last placement values and last size are then cleared. Otherwise it is
     * measured by the file-level `measure` helper for the current transition, its last size is
     * recorded, and it is placed by [place].
     */
    override fun ApproachMeasureScope.approachMeasure(
        measurable: Measurable,
        constraints: Constraints,
    ): MeasureResult {
        val transitions = currentTransitions
        val transition = elementTransition(layoutImpl, element, transitions)

        // If this element is not supposed to be laid out now, either because it is not part of any
        // ongoing transition or the other scene of its transition is overscrolling, then lay out
        // the element normally and don't place it.
        val overscrollScene = transition?.currentOverscrollSpec?.scene
        val isOtherSceneOverscrolling = overscrollScene != null && overscrollScene != scene.key
        val isNotPartOfAnyOngoingTransitions = transitions.isNotEmpty() && transition == null
        if (isNotPartOfAnyOngoingTransitions || isOtherSceneOverscrolling) {
            recursivelyClearPlacementValues()
            sceneState.lastSize = Element.SizeUnspecified

            val placeable = measurable.measure(constraints)
            return layout(placeable.width, placeable.height) { /* Do not place */ }
        }

        val placeable =
            measure(layoutImpl, element, transition, sceneState, measurable, constraints)
        sceneState.lastSize = placeable.size()
        return layout(placeable.width, placeable.height) { place(transition, placeable) }
    }

    /**
     * Places [placeable] for [transition].
     *
     * If the element should not be placed in this scene, placement values are cleared and nothing
     * is placed. Otherwise the offset is computed (including interruption deltas) and recorded in
     * [Element.SceneState.lastOffset]. The element is placed directly when it is opaque with an
     * interrupted alpha of exactly 1f, and with a graphics layer that applies its alpha otherwise.
     */
    @OptIn(ExperimentalComposeUiApi::class)
    private fun Placeable.PlacementScope.place(
        transition: TransitionState.Transition?,
        placeable: Placeable,
    ) {
        with(layoutImpl.lookaheadScope) {
            // The coordinates of this node, needed below to compute its current offset relative to
            // the SceneTransitionLayout.
            val coords =
                coordinates ?: error("Element ${element.key} does not have any coordinates")

            // No need to place the element in this scene if we don't want to draw it anyways.
            if (!shouldPlaceElement(layoutImpl, scene.key, element, transition)) {
                recursivelyClearPlacementValues()
                return
            }

            val currentOffset = lookaheadScopeCoordinates.localPositionOf(coords, Offset.Zero)
            val targetOffset =
                computeValue(
                    layoutImpl,
                    sceneState,
                    element,
                    transition,
                    sceneValue = { it.targetOffset },
                    transformation = { it.offset },
                    currentValue = { currentOffset },
                    isSpecified = { it != Offset.Unspecified },
                    ::lerp,
                )

            // Add the (decaying) offset interruption delta. For shared elements the delta is also
            // mirrored into the other scene of the transition.
            val interruptedOffset =
                computeInterruptedValue(
                    layoutImpl,
                    transition,
                    value = targetOffset,
                    unspecifiedValue = Offset.Unspecified,
                    zeroValue = Offset.Zero,
                    getValueBeforeInterruption = { sceneState.offsetBeforeInterruption },
                    setValueBeforeInterruption = { sceneState.offsetBeforeInterruption = it },
                    getInterruptionDelta = { sceneState.offsetInterruptionDelta },
                    setInterruptionDelta = { delta ->
                        setPlacementInterruptionDelta(
                            element = element,
                            sceneState = sceneState,
                            transition = transition,
                            delta = delta,
                            setter = { sceneState, delta ->
                                sceneState.offsetInterruptionDelta = delta
                            },
                        )
                    },
                    diff = { a, b -> a - b },
                    add = { a, b, bProgress -> a + b * bProgress },
                )

            sceneState.lastOffset = interruptedOffset

            // Placement offset relative to where this node currently is.
            val offset = (interruptedOffset - currentOffset).round()
            if (
                isElementOpaque(scene, element, transition) &&
                    interruptedAlpha(layoutImpl, element, transition, sceneState, alpha = 1f) == 1f
            ) {
                sceneState.lastAlpha = 1f

                // TODO(b/291071158): Call placeWithLayer() if offset != IntOffset.Zero and size is
                // not animated once b/305195729 is fixed. Test that drawing is not invalidated in
                // that case.
                placeable.place(offset)
            } else {
                placeable.placeWithLayer(offset) {
                    // This layer might still run on its own (outside of the placement phase) even
                    // if this element is not placed or composed anymore, so we need to double check
                    // again here before calling [elementAlpha] (which will update
                    // [SceneState.lastAlpha]). We also need to recompute the current transition to
                    // make sure that we are using the current transition and not a reference to an
                    // old one. See b/343138966 for details.
                    if (_element == null) {
                        return@placeWithLayer
                    }

                    val transition = elementTransition(layoutImpl, element, currentTransitions)
                    if (!shouldPlaceElement(layoutImpl, scene.key, element, transition)) {
                        return@placeWithLayer
                    }

                    alpha = elementAlpha(layoutImpl, element, transition, sceneState)
                    compositingStrategy = CompositingStrategy.ModulateAlpha
                }
            }
        }
    }

    /**
     * Recursively clear the last placement values on this node and all descendants ElementNodes.
     * This should be called when this node is not placed anymore, so that we correctly clear values
     * for the descendants for which approachMeasure() won't be called.
     *
     * Only offset, scale and alpha are reset. [Element.SceneState.lastSize] is not touched here.
     */
    private fun recursivelyClearPlacementValues() {
        fun Element.SceneState.clearLastPlacementValues() {
            lastOffset = Offset.Unspecified
            lastScale = Scale.Unspecified
            lastAlpha = Element.AlphaUnspecified
        }

        sceneState.clearLastPlacementValues()
        // Descendant nodes that are currently detached have a null _sceneState and are skipped.
        traverseDescendants(ElementTraverseKey) { node ->
            (node as ElementNode)._sceneState?.clearLastPlacementValues()
            TraversableNode.Companion.TraverseDescendantsAction.ContinueTraversal
        }
    }

    /**
     * Draws the element, scaled by the scale computed for the current transition when it differs
     * from [Scale.Default]. An unspecified pivot falls back to the center. This also marks the
     * element as drawn at least once.
     */
    override fun ContentDrawScope.draw() {
        element.wasDrawnInAnyScene = true

        val transition = elementTransition(layoutImpl, element, currentTransitions)
        val drawScale = getDrawScale(layoutImpl, element, transition, sceneState)
        if (drawScale == Scale.Default) {
            drawContent()
        } else {
            scale(
                drawScale.scaleX,
                drawScale.scaleY,
                if (drawScale.pivot.isUnspecified) center else drawScale.pivot,
            ) {
                this@draw.drawContent()
            }
        }
    }

    companion object {
        private val ElementTraverseKey = Any()

        /**
         * Removes [sceneState] from [element] when no node uses it anymore. It then removes
         * [element] from `layoutImpl.elements` if it has no scene state left.
         */
        private fun maybePruneMaps(
            layoutImpl: SceneTransitionLayoutImpl,
            element: Element,
            sceneState: Element.SceneState,
        ) {
            // If element is not composed from this scene anymore, remove the scene values. This
            // works because [onAttach] is called before [onDetach], so if an element is moved from
            // the UI tree we will first add the new node then remove the old one. The identity
            // check makes sure that we don't remove a newer state registered for the same scene.
            if (sceneState.nodes.isEmpty() && element.sceneStates[sceneState.scene] == sceneState) {
                element.sceneStates.remove(sceneState.scene)

                // If the element is not composed in any scene, remove it from the elements map.
                if (element.sceneStates.isEmpty() && layoutImpl.elements[element.key] == element) {
                    layoutImpl.elements.remove(element.key)
                }
            }
        }
    }
}
461 
/**
 * Returns the transition that should be considered for [element]. This is the last transition in
 * [transitions] whose fromScene or toScene has a state for [element], or `null` if there is none.
 *
 * As a side effect, the result is stored in [Element.lastTransition]. When it differs from the
 * previously stored transition:
 * - if both are non-null, the previous transition was interrupted and the interruption is prepared;
 * - if the new one is null, the transition just finished and all values before interruption and
 *   interruption deltas of [element] are cleared.
 */
private fun elementTransition(
    layoutImpl: SceneTransitionLayoutImpl,
    element: Element,
    transitions: List<TransitionState.Transition>,
): TransitionState.Transition? {
    val sceneStates = element.sceneStates
    val current =
        transitions.fastLastOrNull { candidate ->
            sceneStates.containsKey(candidate.fromScene) ||
                sceneStates.containsKey(candidate.toScene)
        }

    val previous = element.lastTransition
    element.lastTransition = current

    when {
        // No previous transition, or the same one as before: nothing to prepare or clear.
        previous == null || current == previous -> {}

        // The previous transition was interrupted by another transition.
        current != null -> prepareInterruption(layoutImpl, element, current, previous)

        // The transition was just finished.
        else -> {
            for (sceneState in sceneStates.values) {
                sceneState.clearValuesBeforeInterruption()
                sceneState.clearInterruptionDeltas()
            }
        }
    }

    return current
}
492 
/**
 * Prepares [element] for the interruption of [previousTransition] by [transition].
 *
 * The last values of the element in every scene of both transitions are saved as values before
 * interruption. Shared elements are then reconciled for both transitions. Finally, the interruption
 * deltas of all those scenes are reset, and the offset/alpha/scale values before interruption are
 * dropped in the scenes where the element will not be placed during [transition].
 */
private fun prepareInterruption(
    layoutImpl: SceneTransitionLayoutImpl,
    element: Element,
    transition: TransitionState.Transition,
    previousTransition: TransitionState.Transition,
) {
    val sceneStates = element.sceneStates

    // Snapshot the last values in each involved scene. The order is previous from/to, then new
    // from/to. Scenes without a state for this element are skipped.
    val involvedStates =
        listOf(
                previousTransition.fromScene,
                previousTransition.toScene,
                transition.fromScene,
                transition.toScene,
            )
            .mapNotNull { sceneKey ->
                sceneStates[sceneKey]?.also { it.selfUpdateValuesBeforeInterruption() }
            }

    reconcileStates(element, previousTransition)
    reconcileStates(element, transition)

    // Remove the interruption values from all scenes but the scene(s) where the element will be
    // placed, to make sure that interruption deltas are computed only right after this interruption
    // is prepared.
    for (state in involvedStates) {
        state.sizeInterruptionDelta = IntSize.Zero
        state.offsetInterruptionDelta = Offset.Zero
        state.alphaInterruptionDelta = 0f
        state.scaleInterruptionDelta = Scale.Zero

        val placedInThisScene = shouldPlaceElement(layoutImpl, state.scene, element, transition)
        if (!placedInThisScene) {
            state.offsetBeforeInterruption = Offset.Unspecified
            state.alphaBeforeInterruption = Element.AlphaUnspecified
            state.scaleBeforeInterruption = Scale.Unspecified
        }
    }
}
533 
/**
 * Reconciles the state of [element] in the fromScene and toScene of [transition], so that the
 * values before interruption have their expected values when the element is shared.
 *
 * Nothing happens unless [element] has a state in both scenes and its sharedElement()
 * transformation is enabled. If only one of the two scenes has a pending offset before
 * interruption, the values before interruption of that scene are copied to the other one.
 */
private fun reconcileStates(
    element: Element,
    transition: TransitionState.Transition,
) {
    val fromSceneState = element.sceneStates[transition.fromScene] ?: return
    val toSceneState = element.sceneStates[transition.toScene] ?: return
    if (!isSharedElementEnabled(element.key, transition)) return

    val hasFromValues = fromSceneState.offsetBeforeInterruption != Offset.Unspecified
    val hasToValues = toSceneState.offsetBeforeInterruption != Offset.Unspecified
    when {
        // Element is shared and placed in fromScene only.
        hasFromValues && !hasToValues -> toSceneState.updateValuesBeforeInterruption(fromSceneState)

        // Element is shared and placed in toScene only.
        hasToValues && !hasFromValues -> fromSceneState.updateValuesBeforeInterruption(toSceneState)
    }
}
562 
/** Saves the last offset, size, scale and alpha of this scene state as its values before interruption. */
private fun Element.SceneState.selfUpdateValuesBeforeInterruption() {
    this.sizeBeforeInterruption = this.lastSize
    this.offsetBeforeInterruption = this.lastOffset
    this.alphaBeforeInterruption = this.lastAlpha
    this.scaleBeforeInterruption = this.lastScale
}
569 
/**
 * Copies the values before interruption of [lastState] into this scene state, then clears the
 * interruption deltas of this scene state.
 */
private fun Element.SceneState.updateValuesBeforeInterruption(lastState: Element.SceneState) {
    val source = lastState
    this.sizeBeforeInterruption = source.sizeBeforeInterruption
    this.offsetBeforeInterruption = source.offsetBeforeInterruption
    this.alphaBeforeInterruption = source.alphaBeforeInterruption
    this.scaleBeforeInterruption = source.scaleBeforeInterruption

    clearInterruptionDeltas()
}
578 
/** Resets all interruption deltas (size, offset, alpha and scale) of this scene state to zero. */
private fun Element.SceneState.clearInterruptionDeltas() {
    this.sizeInterruptionDelta = IntSize.Zero
    this.alphaInterruptionDelta = 0f
    this.offsetInterruptionDelta = Offset.Zero
    this.scaleInterruptionDelta = Scale.Zero
}
585 
/**
 * Marks the offset, scale and alpha values before interruption of this scene state as unspecified.
 * [Element.SceneState.sizeBeforeInterruption] is intentionally left untouched here.
 */
private fun Element.SceneState.clearValuesBeforeInterruption() {
    this.alphaBeforeInterruption = Element.AlphaUnspecified
    this.scaleBeforeInterruption = Scale.Unspecified
    this.offsetBeforeInterruption = Offset.Unspecified
}
591 
/**
 * Computes what [value] should be when the
 * [interruption progress][TransitionState.Transition.interruptionProgress] of [transition] is
 * taken into account.
 *
 * If a value before interruption is pending (i.e. it is not [unspecifiedValue]), the difference
 * between it and [value] is stored as the interruption delta and the pending value is reset. The
 * result is `value + delta * interruptionProgress`. It is just [value] when there is no
 * [transition], when the delta equals [zeroValue], or when the interruption progress is 0.
 *
 * @param diff computes `a - b`.
 * @param add computes `a + (b * bProgress)`.
 */
private inline fun <T> computeInterruptedValue(
    layoutImpl: SceneTransitionLayoutImpl,
    transition: TransitionState.Transition?,
    value: T,
    unspecifiedValue: T,
    zeroValue: T,
    getValueBeforeInterruption: () -> T,
    setValueBeforeInterruption: (T) -> Unit,
    getInterruptionDelta: () -> T,
    setInterruptionDelta: (T) -> Unit,
    diff: (a: T, b: T) -> T, // a - b
    add: (a: T, b: T, bProgress: Float) -> T, // a + (b * bProgress)
): T {
    // A specified value before interruption means that this is the first time we compute [value]
    // right after an interruption: turn it into a delta and consume it.
    val pendingValue = getValueBeforeInterruption()
    if (pendingValue != unspecifiedValue) {
        setInterruptionDelta(diff(pendingValue, value))
        setValueBeforeInterruption(unspecifiedValue)
    }

    val delta = getInterruptionDelta()
    if (delta == zeroValue || transition == null) {
        // There was no interruption or there is no transition: just return the value.
        return value
    }

    // Add `delta * interruptionProgress`, so that the result animates towards [value].
    val progress = transition.interruptionProgress(layoutImpl)
    return if (progress != 0f) add(value, delta, progress) else value
}
637 
/**
 * Sets the interruption delta of a *placement/drawing*-related value (offset, alpha, scale) on
 * [sceneState] using [setter].
 *
 * When [element] is shared in [transition], the same delta is also set on the other scene of the
 * transition. This avoids a jump cut if the scene where the element is placed changes, for
 * instance when that other scene starts overscrolling.
 */
private inline fun <T> setPlacementInterruptionDelta(
    element: Element,
    sceneState: Element.SceneState,
    transition: TransitionState.Transition?,
    delta: T,
    setter: (Element.SceneState, T) -> Unit,
) {
    // The scene where the delta was computed always receives it.
    setter(sceneState, delta)

    transition ?: return

    val otherScene =
        when (sceneState.scene) {
            transition.fromScene -> transition.toScene
            else -> transition.fromScene
        }
    val otherSceneState = element.sceneStates[otherScene] ?: return
    if (isSharedElementEnabled(element.key, transition)) setter(otherSceneState, delta)
}
666 
/**
 * Whether [element] should be placed in [scene] given its current [transition]:
 * - when idle (no transition), the element is always placed;
 * - it is never placed in a scene that is not part of [transition];
 * - it is always placed when it is not shared, i.e. when it has no state in one of the two scenes of
 *   [transition] or when its sharedElement() transformation is disabled;
 * - otherwise the decision is delegated to [shouldPlaceOrComposeSharedElement].
 */
private fun shouldPlaceElement(
    layoutImpl: SceneTransitionLayoutImpl,
    scene: SceneKey,
    element: Element,
    transition: TransitionState.Transition?,
): Boolean {
    if (transition == null) return true

    val fromScene = transition.fromScene
    val toScene = transition.toScene
    val sceneStates = element.sceneStates
    return when {
        scene != fromScene && scene != toScene -> false
        fromScene !in sceneStates || toScene !in sceneStates -> true
        sharedElementTransformation(element.key, transition)?.enabled == false -> true
        else -> shouldPlaceOrComposeSharedElement(layoutImpl, scene, element.key, transition)
    }
}
703 
/**
 * Whether the shared [element] should be placed (or composed) in [scene] during [transition].
 *
 * While overscrolling, only the overscrolling scene qualifies. Otherwise the scene is chosen by the
 * scenePicker of [element], based on the z-indices of both scenes of [transition]. If the picker
 * returns no scene, the element is placed in none of them.
 */
internal fun shouldPlaceOrComposeSharedElement(
    layoutImpl: SceneTransitionLayoutImpl,
    scene: SceneKey,
    element: ElementKey,
    transition: TransitionState.Transition,
): Boolean {
    transition.currentOverscrollSpec?.scene?.let { overscrollScene ->
        return scene == overscrollScene
    }

    val picker = element.scenePicker
    val pickedScene =
        picker.sceneDuringTransition(
            element = element,
            transition = transition,
            fromSceneZIndex = layoutImpl.scenes.getValue(transition.fromScene).zIndex,
            toSceneZIndex = layoutImpl.scenes.getValue(transition.toScene).zIndex,
        )
    return pickedScene != null && pickedScene == scene
}
730 
isSharedElementEnablednull731 private fun isSharedElementEnabled(
732     element: ElementKey,
733     transition: TransitionState.Transition,
734 ): Boolean {
735     return sharedElementTransformation(element, transition)?.enabled ?: true
736 }
737 
sharedElementTransformationnull738 internal fun sharedElementTransformation(
739     element: ElementKey,
740     transition: TransitionState.Transition,
741 ): SharedElementTransformation? {
742     val transformationSpec = transition.transformationSpec
743     val sharedInFromScene = transformationSpec.transformations(element, transition.fromScene).shared
744     val sharedInToScene = transformationSpec.transformations(element, transition.toScene).shared
745 
746     // The sharedElement() transformation must either be null or be the same in both scenes.
747     if (sharedInFromScene != sharedInToScene) {
748         error(
749             "Different sharedElement() transformations matched $element (from=$sharedInFromScene " +
750                 "to=$sharedInToScene)"
751         )
752     }
753 
754     return sharedInFromScene
755 }
756 
757 /**
758  * Whether the element is opaque or not.
759  *
760  * Important: The logic here should closely match the logic in [elementAlpha]. Note that we don't
761  * reuse [elementAlpha] and simply check if alpha == 1f because [isElementOpaque] is checked during
762  * placement and we don't want to read the transition progress in that phase.
763  */
isElementOpaquenull764 private fun isElementOpaque(
765     scene: Scene,
766     element: Element,
767     transition: TransitionState.Transition?,
768 ): Boolean {
769     if (transition == null) {
770         return true
771     }
772 
773     val fromScene = transition.fromScene
774     val toScene = transition.toScene
775     val fromState = element.sceneStates[fromScene]
776     val toState = element.sceneStates[toScene]
777 
778     if (fromState == null && toState == null) {
779         // TODO(b/311600838): Throw an exception instead once layers of disposed elements are not
780         // run anymore.
781         return true
782     }
783 
784     val isSharedElement = fromState != null && toState != null
785     if (isSharedElement && isSharedElementEnabled(element.key, transition)) {
786         return true
787     }
788 
789     return transition.transformationSpec.transformations(element.key, scene.key).alpha == null
790 }
791 
792 /**
793  * Whether the element is opaque or not.
794  *
795  * Important: The logic here should closely match the logic in [isElementOpaque]. Note that we don't
796  * reuse [elementAlpha] in [isElementOpaque] and simply check if alpha == 1f because
797  * [isElementOpaque] is checked during placement and we don't want to read the transition progress
798  * in that phase.
799  */
elementAlphanull800 private fun elementAlpha(
801     layoutImpl: SceneTransitionLayoutImpl,
802     element: Element,
803     transition: TransitionState.Transition?,
804     sceneState: Element.SceneState,
805 ): Float {
806     val alpha =
807         computeValue(
808                 layoutImpl,
809                 sceneState,
810                 element,
811                 transition,
812                 sceneValue = { 1f },
813                 transformation = { it.alpha },
814                 currentValue = { 1f },
815                 isSpecified = { true },
816                 ::lerp,
817             )
818             .fastCoerceIn(0f, 1f)
819 
820     // If the element is fading during this transition and that it is drawn for the first time, make
821     // sure that it doesn't instantly appear on screen.
822     if (!element.wasDrawnInAnyScene && alpha > 0f) {
823         element.sceneStates.forEach { it.value.alphaBeforeInterruption = 0f }
824     }
825 
826     val interruptedAlpha = interruptedAlpha(layoutImpl, element, transition, sceneState, alpha)
827     sceneState.lastAlpha = interruptedAlpha
828     return interruptedAlpha
829 }
830 
interruptedAlphanull831 private fun interruptedAlpha(
832     layoutImpl: SceneTransitionLayoutImpl,
833     element: Element,
834     transition: TransitionState.Transition?,
835     sceneState: Element.SceneState,
836     alpha: Float,
837 ): Float {
838     return computeInterruptedValue(
839         layoutImpl,
840         transition,
841         value = alpha,
842         unspecifiedValue = Element.AlphaUnspecified,
843         zeroValue = 0f,
844         getValueBeforeInterruption = { sceneState.alphaBeforeInterruption },
845         setValueBeforeInterruption = { sceneState.alphaBeforeInterruption = it },
846         getInterruptionDelta = { sceneState.alphaInterruptionDelta },
847         setInterruptionDelta = { delta ->
848             setPlacementInterruptionDelta(
849                 element = element,
850                 sceneState = sceneState,
851                 transition = transition,
852                 delta = delta,
853                 setter = { sceneState, delta -> sceneState.alphaInterruptionDelta = delta },
854             )
855         },
856         diff = { a, b -> a - b },
857         add = { a, b, bProgress -> a + b * bProgress },
858     )
859 }
860 
measurenull861 private fun measure(
862     layoutImpl: SceneTransitionLayoutImpl,
863     element: Element,
864     transition: TransitionState.Transition?,
865     sceneState: Element.SceneState,
866     measurable: Measurable,
867     constraints: Constraints,
868 ): Placeable {
869     // Some lambdas called (max once) by computeValue() will need to measure [measurable], in which
870     // case we store the resulting placeable here to make sure the element is not measured more than
871     // once.
872     var maybePlaceable: Placeable? = null
873 
874     val targetSize =
875         computeValue(
876             layoutImpl,
877             sceneState,
878             element,
879             transition,
880             sceneValue = { it.targetSize },
881             transformation = { it.size },
882             currentValue = { measurable.measure(constraints).also { maybePlaceable = it }.size() },
883             isSpecified = { it != Element.SizeUnspecified },
884             ::lerp,
885         )
886 
887     // The measurable was already measured, so we can't take interruptions into account here given
888     // that we are not allowed to measure the same measurable twice.
889     maybePlaceable?.let { placeable ->
890         sceneState.sizeBeforeInterruption = Element.SizeUnspecified
891         sceneState.sizeInterruptionDelta = IntSize.Zero
892         return placeable
893     }
894 
895     val interruptedSize =
896         computeInterruptedValue(
897             layoutImpl,
898             transition,
899             value = targetSize,
900             unspecifiedValue = Element.SizeUnspecified,
901             zeroValue = IntSize.Zero,
902             getValueBeforeInterruption = { sceneState.sizeBeforeInterruption },
903             setValueBeforeInterruption = { sceneState.sizeBeforeInterruption = it },
904             getInterruptionDelta = { sceneState.sizeInterruptionDelta },
905             setInterruptionDelta = { sceneState.sizeInterruptionDelta = it },
906             diff = { a, b -> IntSize(a.width - b.width, a.height - b.height) },
907             add = { a, b, bProgress ->
908                 IntSize(
909                     (a.width + b.width * bProgress).roundToInt(),
910                     (a.height + b.height * bProgress).roundToInt(),
911                 )
912             },
913         )
914 
915     return measurable.measure(
916         Constraints.fixed(
917             interruptedSize.width.coerceAtLeast(0),
918             interruptedSize.height.coerceAtLeast(0),
919         )
920     )
921 }
922 
sizenull923 private fun Placeable.size(): IntSize = IntSize(width, height)
924 
/**
 * Compute the draw [Scale] of [element] in the scene of [sceneState] during [transition], taking
 * interruptions into account, and store the result in [Element.SceneState.lastScale].
 *
 * An unspecified pivot means that the scale is applied around the center of the drawing area.
 */
private fun ContentDrawScope.getDrawScale(
    layoutImpl: SceneTransitionLayoutImpl,
    element: Element,
    transition: TransitionState.Transition?,
    sceneState: Element.SceneState,
): Scale {
    val scale =
        computeValue(
            layoutImpl,
            sceneState,
            element,
            transition,
            sceneValue = { Scale.Default },
            transformation = { it.drawScale },
            currentValue = { Scale.Default },
            isSpecified = { true },
            ::lerp,
        )

    // Resolve a pivot *position*: unspecified means the center of the drawing area.
    fun Offset.specifiedOrCenter(): Offset = if (isSpecified) this else center

    // Resolve a pivot *difference*: `diff` below only produces an unspecified pivot when both
    // pivots were unspecified, i.e. both were the center, so the difference is zero (and not the
    // center itself).
    fun Offset.specifiedOrZero(): Offset = if (isSpecified) this else Offset.Zero

    val interruptedScale =
        computeInterruptedValue(
            layoutImpl,
            transition,
            value = scale,
            unspecifiedValue = Scale.Unspecified,
            zeroValue = Scale.Zero,
            getValueBeforeInterruption = { sceneState.scaleBeforeInterruption },
            setValueBeforeInterruption = { sceneState.scaleBeforeInterruption = it },
            getInterruptionDelta = { sceneState.scaleInterruptionDelta },
            setInterruptionDelta = { newDelta ->
                setPlacementInterruptionDelta(
                    element = element,
                    sceneState = sceneState,
                    transition = transition,
                    delta = newDelta,
                    setter = { state, stateDelta -> state.scaleInterruptionDelta = stateDelta },
                )
            },
            diff = { a, b ->
                Scale(
                    scaleX = a.scaleX - b.scaleX,
                    scaleY = a.scaleY - b.scaleY,
                    pivot =
                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
                            Offset.Unspecified
                        } else {
                            a.pivot.specifiedOrCenter() - b.pivot.specifiedOrCenter()
                        }
                )
            },
            // [b] is an interruption delta produced by `diff`, so its pivot is a difference.
            add = { a, b, bProgress ->
                Scale(
                    scaleX = a.scaleX + b.scaleX * bProgress,
                    scaleY = a.scaleY + b.scaleY * bProgress,
                    pivot =
                        if (a.pivot.isUnspecified && b.pivot.isUnspecified) {
                            Offset.Unspecified
                        } else {
                            a.pivot.specifiedOrCenter() + b.pivot.specifiedOrZero() * bProgress
                        }
                )
            }
        )

    sceneState.lastScale = interruptedScale
    return interruptedScale
}
996 
997 /**
998  * Return the value that should be used depending on the current layout state and transition.
999  *
1000  * Important: This function must remain inline because of all the lambda parameters. These lambdas
1001  * are necessary because getting some of them might require some computation, like measuring a
1002  * Measurable.
1003  *
1004  * @param layoutImpl the [SceneTransitionLayoutImpl] associated to [element].
1005  * @param currentSceneState the scene state of the scene for which we are computing the value. Note
1006  *   that during interruptions, this could be the state of a scene that is neither
1007  *   [transition.toScene] nor [transition.fromScene].
1008  * @param element the element being animated.
1009  * @param sceneValue the value being animated.
1010  * @param transformation the transformation associated to the value being animated.
1011  * @param currentValue the value that would be used if it is not transformed. Note that this is
1012  *   different than [idleValue] even if the value is not transformed directly because it could be
1013  *   impacted by the transformations on other elements, like a parent that is being translated or
1014  *   resized.
1015  * @param lerp the linear interpolation function used to interpolate between two values of this
1016  *   value type.
1017  */
computeValuenull1018 private inline fun <T> computeValue(
1019     layoutImpl: SceneTransitionLayoutImpl,
1020     currentSceneState: Element.SceneState,
1021     element: Element,
1022     transition: TransitionState.Transition?,
1023     sceneValue: (Element.SceneState) -> T,
1024     transformation: (ElementTransformations) -> PropertyTransformation<T>?,
1025     currentValue: () -> T,
1026     isSpecified: (T) -> Boolean,
1027     lerp: (T, T, Float) -> T,
1028 ): T {
1029     if (transition == null) {
1030         // There is no ongoing transition. Even if this element SceneTransitionLayout is not
1031         // animated, the layout itself might be animated (e.g. by another parent
1032         // SceneTransitionLayout), in which case this element still need to participate in the
1033         // layout phase.
1034         return currentValue()
1035     }
1036 
1037     val fromScene = transition.fromScene
1038     val toScene = transition.toScene
1039 
1040     val fromState = element.sceneStates[fromScene]
1041     val toState = element.sceneStates[toScene]
1042 
1043     if (fromState == null && toState == null) {
1044         // TODO(b/311600838): Throw an exception instead once layers of disposed elements are not
1045         // run anymore.
1046         return sceneValue(currentSceneState)
1047     }
1048 
1049     val currentScene = currentSceneState.scene
1050     if (transition is TransitionState.HasOverscrollProperties) {
1051         val overscroll = transition.currentOverscrollSpec
1052         if (overscroll?.scene == currentScene) {
1053             val elementSpec =
1054                 overscroll.transformationSpec.transformations(element.key, currentScene)
1055             val propertySpec = transformation(elementSpec) ?: return currentValue()
1056             val overscrollState = checkNotNull(if (currentScene == toScene) toState else fromState)
1057             val idleValue = sceneValue(overscrollState)
1058             val targetValue =
1059                 propertySpec.transform(
1060                     layoutImpl,
1061                     currentScene,
1062                     element,
1063                     overscrollState,
1064                     transition,
1065                     idleValue,
1066                 )
1067 
1068             // Make sure we don't read progress if values are the same and we don't need to
1069             // interpolate, so we don't invalidate the phase where this is read.
1070             if (targetValue == idleValue) {
1071                 return targetValue
1072             }
1073 
1074             // TODO(b/290184746): Make sure that we don't overflow transformations associated to a
1075             // range.
1076             val directionSign = if (transition.isUpOrLeft) -1 else 1
1077             val isToScene = overscroll.scene == transition.toScene
1078             val overscrollProgress = transition.progress.let { if (isToScene) it - 1f else it }
1079             val progress = directionSign * overscrollProgress
1080             val rangeProgress = propertySpec.range?.progress(progress) ?: progress
1081 
1082             // Interpolate between the value at rest and the over scrolled value.
1083             return lerp(idleValue, targetValue, rangeProgress)
1084         }
1085     }
1086 
1087     // The element is shared: interpolate between the value in fromScene and the value in toScene.
1088     // TODO(b/290184746): Support non linear shared paths as well as a way to make sure that shared
1089     // elements follow the finger direction.
1090     val isSharedElement = fromState != null && toState != null
1091     if (isSharedElement && isSharedElementEnabled(element.key, transition)) {
1092         val start = sceneValue(fromState!!)
1093         val end = sceneValue(toState!!)
1094 
1095         // TODO(b/316901148): Remove checks to isSpecified() once the lookahead pass runs for all
1096         // nodes before the intermediate layout pass.
1097         if (!isSpecified(start)) return end
1098         if (!isSpecified(end)) return start
1099 
1100         // Make sure we don't read progress if values are the same and we don't need to interpolate,
1101         // so we don't invalidate the phase where this is read.
1102         return if (start == end) start else lerp(start, end, transition.progress)
1103     }
1104 
1105     // Get the transformed value, i.e. the target value at the beginning (for entering elements) or
1106     // end (for leaving elements) of the transition.
1107     val sceneState =
1108         checkNotNull(
1109             when {
1110                 isSharedElement && currentScene == fromScene -> fromState
1111                 isSharedElement -> toState
1112                 else -> fromState ?: toState
1113             }
1114         )
1115 
1116     // The scene for which we compute the transformation. Note that this is not necessarily
1117     // [currentScene] because [currentScene] could be a different scene than the transition
1118     // fromScene or toScene during interruptions.
1119     val scene = sceneState.scene
1120 
1121     val transformation =
1122         transformation(transition.transformationSpec.transformations(element.key, scene))
1123             // If there is no transformation explicitly associated to this element value, let's use
1124             // the value given by the system (like the current position and size given by the layout
1125             // pass).
1126             ?: return currentValue()
1127 
1128     val idleValue = sceneValue(sceneState)
1129     val targetValue =
1130         transformation.transform(
1131             layoutImpl,
1132             scene,
1133             element,
1134             sceneState,
1135             transition,
1136             idleValue,
1137         )
1138 
1139     // Make sure we don't read progress if values are the same and we don't need to interpolate, so
1140     // we don't invalidate the phase where this is read.
1141     if (targetValue == idleValue) {
1142         return targetValue
1143     }
1144 
1145     val progress = transition.progress
1146     // TODO(b/290184746): Make sure that we don't overflow transformations associated to a range.
1147     val rangeProgress = transformation.range?.progress(progress) ?: progress
1148 
1149     // Interpolate between the value at rest and the value before entering/after leaving.
1150     val isEntering = scene == toScene
1151     return if (isEntering) {
1152         lerp(targetValue, idleValue, rangeProgress)
1153     } else {
1154         lerp(idleValue, targetValue, rangeProgress)
1155     }
1156 }
1157