001/**
002 * Copyright (c) 2025-2026, Michael Yang 杨福海 (fuhai999@gmail.com).
003 * <p>
004 * Licensed under the GNU Lesser General Public License (LGPL) ,Version 3.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 * <p>
008 * http://www.gnu.org/licenses/lgpl-3.0.txt
009 * <p>
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package dev.tinyflow.core.node;
017
018import com.alibaba.fastjson.JSON;
019import dev.tinyflow.core.chain.Chain;
020import dev.tinyflow.core.chain.Parameter;
021import dev.tinyflow.core.llm.Llm;
022import dev.tinyflow.core.llm.LlmManager;
023import dev.tinyflow.core.util.*;
024
025import java.io.File;
026import java.util.*;
027
028public class LlmNode extends BaseNode {
029
030    protected String llmId;
031    protected Llm.ChatOptions chatOptions;
032    protected String userPrompt;
033    protected String systemPrompt;
034    protected String jsonSchema;
035    protected String outType = "text"; //text markdown json
036    protected List<Parameter> images;
037
038    public LlmNode() {
039    }
040
041    public String getLlmId() {
042        return llmId;
043    }
044
045    public void setLlmId(String llmId) {
046        this.llmId = llmId;
047    }
048
049    public String getUserPrompt() {
050        return userPrompt;
051    }
052
053    public void setUserPrompt(String userPrompt) {
054        this.userPrompt = userPrompt;
055    }
056
057    public String getSystemPrompt() {
058        return systemPrompt;
059    }
060
061    public void setSystemPrompt(String systemPrompt) {
062        this.systemPrompt = systemPrompt;
063    }
064
065    public String getJsonSchema() {
066        return jsonSchema;
067    }
068
069    public void setJsonSchema(String jsonSchema) {
070        this.jsonSchema = jsonSchema;
071    }
072
073    public Llm.ChatOptions getChatOptions() {
074        return chatOptions;
075    }
076
077    public void setChatOptions(Llm.ChatOptions chatOptions) {
078        this.chatOptions = chatOptions;
079    }
080
081    public String getOutType() {
082        return outType;
083    }
084
085    public void setOutType(String outType) {
086        this.outType = outType;
087    }
088
089    public List<Parameter> getImages() {
090        return images;
091    }
092
093    public void setImages(List<Parameter> images) {
094        this.images = images;
095    }
096
097    @Override
098    public Map<String, Object> execute(Chain chain) {
099        Map<String, Object> formatParameters = getFormatParameters(chain);
100
101        if (StringUtil.noText(userPrompt)) {
102            throw new RuntimeException("Can not find user prompt");
103        }
104
105        String userPromptString = TextTemplate.of(userPrompt).formatToString(formatParameters);
106
107
108        Llm llm = LlmManager.getInstance().getChatModel(this.llmId);
109        if (llm == null) {
110            throw new RuntimeException("Can not find llm: " + this.llmId);
111        }
112
113        String systemPromptString = TextTemplate.of(this.systemPrompt).formatToString(formatParameters);
114
115        Llm.MessageInfo messageInfo = new Llm.MessageInfo();
116        messageInfo.setMessage(userPromptString);
117        messageInfo.setSystemMessage(systemPromptString);
118
119        if (images != null && !images.isEmpty()) {
120            Map<String, Object> filesMap = chain.getState().resolveParameters(this, images);
121            List<String> imagesUrls = new ArrayList<>();
122            filesMap.forEach((s, o) -> {
123                if (o instanceof String) {
124                    imagesUrls.add((String) o);
125                } else if (o instanceof File) {
126                    byte[] bytes = IOUtil.readBytes((File) o);
127                    String base64 = Base64.getEncoder().encodeToString(bytes);
128                    imagesUrls.add(base64);
129                }
130            });
131            messageInfo.setImages(imagesUrls);
132        }
133
134
135        String responseContent = llm.chat(messageInfo, chatOptions, this, chain);
136
137        if (StringUtil.noText(responseContent)) {
138            throw new RuntimeException("Can not get response from llm");
139        } else {
140            responseContent = responseContent.trim();
141        }
142
143
144        if ("json".equalsIgnoreCase(outType)) {
145            Object jsonObjectOrArray;
146            try {
147                jsonObjectOrArray = JSON.parse(unWrapMarkdown(responseContent));
148            } catch (Exception e) {
149                throw new RuntimeException("Can not parse json: " + responseContent + " " + e.getMessage());
150            }
151
152            if (CollectionUtil.noItems(this.outputDefs)) {
153                return Maps.of("root", jsonObjectOrArray);
154            } else {
155                Parameter parameter = this.outputDefs.get(0);
156                return Maps.of(parameter.getName(), jsonObjectOrArray);
157            }
158        } else {
159            if (CollectionUtil.noItems(this.outputDefs)) {
160                return Maps.of("output", responseContent);
161            } else {
162                Parameter parameter = this.outputDefs.get(0);
163                return Maps.of(parameter.getName(), responseContent);
164            }
165        }
166    }
167
168
169    /**
170     * 移除 ``` 或者 ```json 等
171     *
172     * @param markdown json内容
173     * @return 方法 json 内容
174     */
175    public static String unWrapMarkdown(String markdown) {
176        // 移除开头的 ```json 或 ```
177        if (markdown.startsWith("```")) {
178            int newlineIndex = markdown.indexOf('\n');
179            if (newlineIndex != -1) {
180                markdown = markdown.substring(newlineIndex + 1);
181            } else {
182                // 如果没有换行符,直接去掉 ``` 部分
183                markdown = markdown.substring(3);
184            }
185        }
186
187        // 移除结尾的 ```
188        if (markdown.endsWith("```")) {
189            markdown = markdown.substring(0, markdown.length() - 3);
190        }
191        return markdown.trim();
192    }
193
194
195    @Override
196    public String toString() {
197        return "LlmNode{" +
198                "llmId='" + llmId + '\'' +
199                ", chatOptions=" + chatOptions +
200                ", userPrompt='" + userPrompt + '\'' +
201                ", systemPrompt='" + systemPrompt + '\'' +
202                ", outType='" + outType + '\'' +
203                ", images=" + images +
204                ", parameters=" + parameters +
205                ", outputDefs=" + outputDefs +
206                ", id='" + id + '\'' +
207                ", name='" + name + '\'' +
208                ", description='" + description + '\'' +
209                ", condition=" + condition +
210                ", validator=" + validator +
211                ", loopEnable=" + loopEnable +
212                ", loopIntervalMs=" + loopIntervalMs +
213                ", loopBreakCondition=" + loopBreakCondition +
214                ", maxLoopCount=" + maxLoopCount +
215                ", retryEnable=" + retryEnable +
216                ", resetRetryCountAfterNormal=" + resetRetryCountAfterNormal +
217                ", maxRetryCount=" + maxRetryCount +
218                ", retryIntervalMs=" + retryIntervalMs +
219                ", computeCostExpr='" + computeCostExpr + '\'' +
220                '}';
221    }
222}