view src/lib/ai/claude/Ai_chat.luan @ 70:4a73af8f2203 default tip

fix
author Franklin Schmidt <fschmidt@gmail.com>
date Sat, 23 Aug 2025 12:00:16 -0600
parents f5e72f2d1025
children
line wrap: on
line source

local Luan = require "luan:Luan.luan"
local error = Luan.error
local ipairs = Luan.ipairs or error()
local pairs = Luan.pairs or error()
local type = Luan.type or error()
local String = require "luan:String.luan"
local starts_with = String.starts_with or error()
local Html = require "luan:Html.luan"
local html_encode = Html.encode or error()
local Parsers = require "luan:Parsers.luan"
local json_parse = Parsers.json_parse or error()
local json_string = Parsers.json_string or error()
local Claude = require "site:/lib/ai/claude/Claude.luan"
local claude_chat = Claude.chat or error()
local Logging = require "luan:logging/Logging.luan"
local logger = Logging.logger "claude/Ai_chat"


-- Module table: helpers for rendering and advancing chats with Claude.
-- A chat's thread is passed around as a JSON string in the Claude API format.
local Ai_chat = {}

-- Write the thread's system prompt, HTML-encoded, to the template output.
-- thread: JSON string of a thread in Claude API format; must contain "system".
-- Raises if the thread has no system prompt.
function Ai_chat.output_system_prompt(thread)
	thread = json_parse(thread)
	-- error() must be called: a bare `or error` would yield the function value
	-- and hand it to html_encode instead of raising here.
	local system_prompt = thread.system or error()
	system_prompt = html_encode(system_prompt)
	%><%=system_prompt%><%
end

-- Write the messages of a thread as HTML to the template output.
-- assistant_controls: HTML inserted unescaped into every assistant message.
-- thread: JSON string of a thread in Claude API format (must have "messages").
-- old_thread: optional JSON string of an earlier version of the same thread;
--   when given, its first #messages entries are skipped so that only messages
--   added since then are written.
-- Message text is HTML-encoded. For list-form content only "text" parts are
-- rendered, so tool_use / tool_result parts produce no output.
-- Raises on a role other than "assistant" or "user", or on missing fields.
function Ai_chat.output_messages_html(assistant_controls,thread,old_thread)
	thread = json_parse(thread)
	local messages = thread.messages or error()
	local n = 0
	if old_thread ~= nil then
		old_thread = json_parse(old_thread)
		local old_messages = old_thread.messages or error()
		n = #old_messages
	end
	for i, message in ipairs(messages) do
		-- skip messages that were already present in old_thread
		if i <= n then
			continue
		end
		local role = message.role or error()
		local who
		if role=="assistant" then
			who = "Claude"
		elseif role=="user" then
			who = "You"
		else
			error(role)
		end
		-- Emit one message block; role is one of the two values checked above,
		-- so it is safe to put in the attribute unencoded.
		local function output(text)
			text = html_encode(text)
%>
			<h3><%=who%></h3>
			<div role="<%=role%>">
				<div message markdown><%=text%></div>
<%			if role=="assistant" then %>
<%=				assistant_controls %>
<%			end %>
			</div>
<%
		end
		local content = message.content or error()
		if type(content) == "string" then
			output(content)
		else
			for _, part in ipairs(content) do
				if part.type=="text" then
					local text = part.text or error()
					output(text)
				end
			end
		end
	end_for
end

-- Look up a chat by ID for use by a tool.
-- Raises if no such chat exists, or with "private" if the chat is private and
-- the current user (if any) is not its owner.
local function get_chat(chat_id)
	local Chat = require "site:/lib/Chat.luan"
	local User = require "site:/lib/User.luan"
	local found = Chat.get_by_id(chat_id) or error()
	local current = User.current()
	if found.is_private then
		local owns_it = current ~= nil and current.id == found.user_id
		if not owns_it then
			error "private"
		end
	end
	return found
end

-- Input schema shared by every tool: a single integer "chat_id" argument.
-- A fresh table is built per call so no two tools share one schema table.
local function chat_id_input_schema()
	return {
		type = "object"
		properties = {
			chat_id = {
				description = "The ID of the chat"
				type = "integer"
			}
		}
	}
end

-- Tools Claude may call, keyed by tool name. Each entry has:
--   tool: the definition sent to the Claude API ("name" is filled in below)
--   fn:   handler called with the tool_use "input" table; its return value
--         becomes the tool_result content
-- Every handler goes through get_chat, so other users' private chats are refused.
local functions = {
	get_chat = {
		tool = {
			description = "Get the contents of a chat/thread with Claude on this website.  The contents will be JSON in the format of the Claude API."
			input_schema = chat_id_input_schema()
		}
		fn = function(input)
			local chat_id = input.chat_id or error()
			local chat = get_chat(chat_id)
			return chat.ai_thread or error()
		end
	}
	get_tts_instructions = {
		tool = {
			description = "Get the text-to-speech instructions of a chat/thread on this website.  These instructions are passed to OpenAI.  If there are no instructions, the empty string is returned."
			input_schema = chat_id_input_schema()
		}
		fn = function(input)
			local chat_id = input.chat_id or error()
			local chat = get_chat(chat_id)
			-- honor the description: no instructions means "", not an error
			return chat.tts_instructions or ""
		end
	}
	get_stt_prompt = {
		tool = {
			description = "Get the speech-to-text prompt of a chat/thread on this website.  This prompt is passed to OpenAI.  If there is no prompt, the empty string is returned."
			input_schema = chat_id_input_schema()
		}
		fn = function(input)
			local chat_id = input.chat_id or error()
			local chat = get_chat(chat_id)
			-- honor the description: no prompt means "", not an error
			return chat.stt_prompt or ""
		end
	}
}

-- Tool definitions as placed in a thread's "tools" list. Each entry's name is
-- copied from its key in `functions`.
local tools = {nil}
for name, f in pairs(functions) do
	f.name = name
	f.tool.name = name
	tools[#tools+1] = f.tool
end

-- Create a new thread and return it as a JSON string: the given system prompt,
-- the module-level tool definitions, and an empty message list.
function Ai_chat.init(system_prompt)
	-- NOTE(review): {nil} presumably makes json_string emit [] rather than {}; confirm
	local no_messages = {nil}
	return json_string({
		system = system_prompt
		tools = tools
		messages = no_messages
	})
end

-- Return true if the thread (a JSON string) contains at least one message.
function Ai_chat.has_messages(thread)
	local messages = json_parse(thread).messages
	return #messages ~= 0
end

-- Append a user message to `thread` (a parsed thread table, modified in place),
-- send the thread to Claude, and append Claude's reply.
-- If Claude stops with "tool_use", each requested tool from `functions` is run,
-- and the results are sent back as the next user message (recursively) until
-- Claude ends its turn.
-- content: a string, or a list of content parts (tool_result parts when recursing).
-- Raises on an unexpected response, an unknown tool, or any stop_reason other
-- than "end_turn" / "tool_use".
local function ask(thread,content)
	-- error() must be called; a bare `or error` would not raise
	local messages = thread.messages or error()
	messages[#messages+1] = {
		role = "user"
		content = content
	}
-- Debug stub, disabled by this block comment: when enabled it appends a canned
-- assistant reply and returns without calling Claude.
--[=[
	messages[#messages+1] = {
		role = "assistant"
		content = [[
hello
]]
	}
	if true then
		return
	end
--]=]
	-- logger.info(json_string(thread))
	local resultJson = claude_chat(thread)
	local result = json_parse(resultJson)
	-- logger.info(json_string(result))
	result.type == "message" or error()
	result.role == "assistant" or error()
	result.stop_reason == "end_turn" or result.stop_reason == "tool_use" or error()
	local reply = result.content or error()
	messages[#messages+1] = {
		role = "assistant"
		content = reply
	}
	local stop_reason = result.stop_reason or error()
	if stop_reason == "end_turn" then
		-- ok
	elseif stop_reason == "tool_use" then
		-- answer every tool_use part with a matching tool_result part
		local response = {nil}
		for _, part in ipairs(reply) do
			if part.type == "tool_use" then
				local f = functions[part.name] or error()
				local input = part.input or error()
				response[#response+1] = {
					type = "tool_result"
					tool_use_id = part.id or error()
					content = f.fn(input)
				}
			end
		end
		ask(thread,response)
	else
		error(stop_reason)
	end
end

-- Send `input` as the user's next message in `thread` (a JSON string), and let
-- Claude respond, running any tools it requests. Return the updated thread as
-- a JSON string.
function Ai_chat.ask(thread,input)
	local parsed_thread = json_parse(thread)
	ask(parsed_thread,input)
	return json_string(parsed_thread)
end

-- Export the module table.
return Ai_chat