001/**
002 * Copyright (c) 2025-2026, Michael Yang 杨福海 (fuhai999@gmail.com).
003 * <p>
004 * Licensed under the GNU Lesser General Public License (LGPL) ,Version 3.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 * <p>
008 * http://www.gnu.org/licenses/lgpl-3.0.txt
009 * <p>
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package dev.tinyflow.core.node;
017
018
019import dev.tinyflow.core.chain.*;
020import dev.tinyflow.core.chain.repository.ChainStateField;
021import dev.tinyflow.core.chain.repository.NodeStateField;
022import dev.tinyflow.core.chain.runtime.Trigger;
023import dev.tinyflow.core.chain.runtime.TriggerContext;
024import dev.tinyflow.core.chain.runtime.TriggerType;
025import dev.tinyflow.core.util.IterableUtil;
026import dev.tinyflow.core.util.Maps;
027import dev.tinyflow.core.util.StringUtil;
028
029import java.io.Serializable;
030import java.util.*;
031import java.util.concurrent.ConcurrentHashMap;
032
033public class LoopNode extends BaseNode {
034
035    private Parameter loopVar;
036
037    public Parameter getLoopVar() {
038        return loopVar;
039    }
040
041    public void setLoopVar(Parameter loopVar) {
042        this.loopVar = loopVar;
043    }
044
045    @Override
046    public Map<String, Object> execute(Chain chain) {
047        Trigger prevTrigger = TriggerContext.getCurrentTrigger();
048        Deque<LoopContext> loopStack = getOrCreateLoopStack(chain);
049
050        LoopContext loopContext;
051
052        // 判断是否是首次进入该 LoopNode(即不是由子节点返回)
053        TriggerType triggerType = prevTrigger.getType();
054        boolean isFirstEntry = triggerType != TriggerType.PARENT && triggerType != TriggerType.SELF;
055
056        if (isFirstEntry) {
057            // 首次触发:创建新的 LoopContext 并压入堆栈
058            loopContext = new LoopContext();
059            loopContext.currentIndex = 0;
060            loopContext.subResult = new HashMap<>();
061            // 保存原始触发上下文(用于循环结束后恢复)
062            loopStack.offerLast(loopContext);
063
064            chain.updateNodeStateSafely(this.id, state -> {
065                state.getMemory().put(buildLoopStackId(), loopStack);
066                return EnumSet.of(NodeStateField.MEMORY);
067            });
068
069            if (loopStack.size() > 1) {
070                // 不执行,等等其他节点唤起
071                return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
072                        .set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
073            }
074        }
075        // 由子节点返回:从堆栈低部获取当前循环上下文
076        else {
077            if (loopStack.isEmpty()) {
078                throw new IllegalStateException("Loop stack is empty when returning from child node.");
079            }
080            loopContext = loopStack.peekFirst();
081        }
082
083
084//       LoopContext loopContext = getLoopContext(prevTrigger, chain);
085//        int triggerLoopIndex = getTriggerLoopIndex(prevTrigger);
086//
087//        if (loopContext.currentIndex != triggerLoopIndex) {
088//            // 不执行,子流程有分叉,已经被其他的分叉节点触发了
089//            return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
090//                    .set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
091//        }
092
093        Map<String, Object> loopVars = chain.getState().resolveParameters(this, Collections.singletonList(loopVar));
094        Object loopValue = loopVars.get(loopVar.getName());
095
096        int shouldLoopCount;
097        if (loopValue instanceof Iterable) {
098            shouldLoopCount = IterableUtil.size((Iterable<?>) loopValue);
099        } else if (loopValue instanceof Number || (loopValue instanceof String && StringUtil.isNumeric(loopValue.toString()))) {
100            shouldLoopCount = loopValue instanceof Number ? ((Number) loopValue).intValue() : Integer.parseInt(loopValue.toString().trim());
101        } else {
102            throw new IllegalArgumentException("loopValue must be Iterable or Number or String, but loopValue is \"" + loopValue + "\"");
103        }
104
105        //  不是第一次执行,合并结果到 subResult
106        if (loopContext.currentIndex != 0) {
107            ChainState subState = chain.getState();
108            mergeResult(loopContext.subResult, subState);
109        }
110
111
112        // 执行的次数够了, 恢复父级触发
113        if (loopContext.currentIndex >= shouldLoopCount) {
114            loopStack.pollFirst();    // 移除最顶部部的 LoopContext
115            chain.updateNodeStateSafely(this.id, state -> {
116                ConcurrentHashMap<String, Object> memory = state.getMemory();
117                memory.put(buildLoopStackId(), loopStack);
118                memory.remove(this.id + ".index");
119                memory.remove(this.id + ".loopItem");
120                return EnumSet.of(NodeStateField.MEMORY);
121            });
122            if (!loopStack.isEmpty()) {
123                chain.scheduleNode(this, null, TriggerType.SELF, 0);
124            }
125            return loopContext.subResult;
126        }
127
128        int loopIndex = loopContext.currentIndex;
129        loopContext.currentIndex++;
130
131        chain.updateNodeStateSafely(this.id, state -> {
132            state.getMemory().put(buildLoopStackId(), loopStack);
133            return EnumSet.of(NodeStateField.MEMORY);
134        });
135
136
137        if (loopValue instanceof Iterable) {
138            Object loopItem = IterableUtil.get((Iterable<?>) loopValue, loopIndex);
139            executeLoopChain(chain, loopContext, loopItem);
140        } else if (loopValue instanceof Number || (loopValue instanceof String && StringUtil.isNumeric(loopValue.toString()))) {
141            executeLoopChain(chain, loopContext, loopIndex);
142        } else {
143            throw new IllegalArgumentException("loopValue must be Iterable or Number or String, but loopValue is \"" + loopValue + "\"");
144        }
145
146        // 禁用调度下个节点
147        return Maps.of(ChainConsts.SCHEDULE_NEXT_NODE_DISABLED_KEY, true)
148                .set(ChainConsts.NODE_STATE_STATUS_KEY, NodeStatus.RUNNING);
149    }
150
151
152    /**
153     * 获取或创建当前节点的 LoopContext 堆栈(每个 LoopNode 实例独立)
154     */
155    @SuppressWarnings("unchecked")
156    private Deque<LoopContext> getOrCreateLoopStack(Chain chain) {
157        NodeState nodeState = chain.getNodeState(this.id);
158        String key = buildLoopStackId();
159        Object stackObj = nodeState.getMemory().get(key);
160        Deque<LoopContext> stack;
161        if (stackObj instanceof Deque) {
162            stack = (Deque<LoopContext>) stackObj;
163        } else {
164            stack = new ArrayDeque<>();
165            chain.updateNodeStateSafely(this.id, state -> {
166                state.getMemory().put(key, stack);
167                return EnumSet.of(NodeStateField.MEMORY);
168            });
169        }
170        return stack;
171    }
172
173
174    private void executeLoopChain(Chain chain, LoopContext loopContext, Object loopItem) {
175
176        chain.updateStateSafely(state -> {
177            ConcurrentHashMap<String, Object> memory = state.getMemory();
178            memory.put(this.id + ".index", (loopContext.currentIndex - 1));
179            memory.put(this.id + ".loopItem", loopItem);
180            return EnumSet.of(ChainStateField.MEMORY);
181        });
182
183
184        ChainDefinition definition = chain.getDefinition();
185        List<Edge> outwardEdges = definition.getOutwardEdge(this.id);
186        for (Edge edge : outwardEdges) {
187            Node childNode = definition.getNodeById(edge.getTarget());
188            if (childNode.getParentId() != null && childNode.getParentId().equals(this.id)) {
189                chain.scheduleNode(childNode, edge.getId(), TriggerType.CHILD, 0);
190            }
191        }
192    }
193
194
195    /**
196     * 把子流程执行的结果填充到主流程的输出参数中
197     *
198     * @param toResult 主流程的输出参数
199     * @param subState 子流程的
200     */
201    private void mergeResult(Map<String, Object> toResult, ChainState subState) {
202        List<Parameter> outputDefs = getOutputDefs();
203        if (outputDefs != null) {
204            for (Parameter outputDef : outputDefs) {
205                Object value = null;
206
207                //引用
208                if (outputDef.getRefType() == RefType.REF) {
209                    value = subState.resolveValue(outputDef.getRef());
210                }
211                //固定值
212                else if (outputDef.getRefType() == RefType.FIXED) {
213                    value = outputDef.getValue();
214                }
215
216                @SuppressWarnings("unchecked") List<Object> existList = (List<Object>) toResult.get(outputDef.getName());
217                if (existList == null) {
218                    existList = new ArrayList<>();
219                }
220                existList.add(value);
221                toResult.put(outputDef.getName(), existList);
222            }
223        }
224    }
225
226
227    private String buildLoopStackId() {
228        return this.getId() + "__loop__context";
229    }
230
231
232    public static class LoopContext implements Serializable {
233        int currentIndex;
234        Map<String, Object> subResult;
235
236        public int getCurrentIndex() {
237            return currentIndex;
238        }
239
240        public void setCurrentIndex(int currentIndex) {
241            this.currentIndex = currentIndex;
242        }
243
244        public Map<String, Object> getSubResult() {
245            return subResult;
246        }
247
248        public void setSubResult(Map<String, Object> subResult) {
249            this.subResult = subResult;
250        }
251
252    }
253}