changeset 21:c93d8c781853

add functions git-svn-id: https://luan-java.googlecode.com/svn/trunk@22 21e917c8-12df-6dd8-5cb6-c86387c605b9
author fschmidt@gmail.com <fschmidt@gmail.com@21e917c8-12df-6dd8-5cb6-c86387c605b9>
date Tue, 04 Dec 2012 09:16:03 +0000
parents d85510d92eee
children 1e37f22a34c8
files src/luan/CmdLine.java src/luan/LuaClosure.java src/luan/LuaJavaFunction.java src/luan/LuaState.java src/luan/interp/Block.java src/luan/interp/Chunk.java src/luan/interp/LuaParser.java src/luan/interp/ReturnException.java src/luan/interp/ReturnStmt.java src/luan/interp/SetStmt.java src/luan/lib/BasicLib.java
diffstat 11 files changed, 264 insertions(+), 101 deletions(-) [+]
line wrap: on
line diff
diff -r d85510d92eee -r c93d8c781853 src/luan/CmdLine.java
--- a/src/luan/CmdLine.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/CmdLine.java	Tue Dec 04 09:16:03 2012 +0000
@@ -27,7 +27,7 @@
 		} else {
 			String file = args[i++];
 			try {
-				LuaFunction fn = BasicLib.loadFile(file);
+				LuaFunction fn = BasicLib.loadFile(lua,file);
 				fn.call(lua);
 			} catch(LuaException e) {
 //				System.out.println(e.getMessage());
@@ -44,7 +44,7 @@
 			System.out.print("> ");
 			String input = new Scanner(System.in).nextLine();
 			try {
-				LuaFunction fn = BasicLib.load(input);
+				LuaFunction fn = BasicLib.load(lua,input);
 				Object[] rtn = fn.call(lua);
 				if( rtn.length > 0 )
 					BasicLib.print(rtn);
diff -r d85510d92eee -r c93d8c781853 src/luan/LuaClosure.java
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/luan/LuaClosure.java	Tue Dec 04 09:16:03 2012 +0000
@@ -0,0 +1,29 @@
+package luan;
+
+import luan.interp.Chunk;
+import luan.interp.ReturnException;
+
+
+public final class LuaClosure extends LuaFunction {
+	private final Chunk chunk;
+
+	public LuaClosure(Chunk chunk,LuaState lua) {
+		this.chunk = chunk;
+	}
+
+	public Object[] call(LuaState lua,Object... args) throws LuaException {
+		Object[] stack = lua.newStack(chunk.stackSize);
+		final int n = Math.min(args.length,chunk.numArgs);
+		for( int i=0; i<n; i++ ) {
+			stack[i] = args[i];
+		}
+		try {
+			chunk.block.eval(lua);
+		} catch(ReturnException e) {
+		} finally {
+			lua.popStack();
+		}
+		return lua.returnValues;
+	}
+
+}
diff -r d85510d92eee -r c93d8c781853 src/luan/LuaJavaFunction.java
--- a/src/luan/LuaJavaFunction.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/LuaJavaFunction.java	Tue Dec 04 09:16:03 2012 +0000
@@ -9,6 +9,7 @@
 	private final Method method;
 	private final Object obj;
 	private final RtnConverter rtnConverter;
+	private final boolean takesLuaState;
 	private final ArgConverter[] argConverters;
 	private final Class<?> varArgCls;
 
@@ -16,7 +17,8 @@
 		this.method = method;
 		this.obj = obj;
 		this.rtnConverter = getRtnConverter(method);
-		this.argConverters = getArgConverters(method);
+		this.takesLuaState = takesLuaState(method);
+		this.argConverters = getArgConverters(takesLuaState,method);
 		if( method.isVarArgs() ) {
 			Class<?>[] paramTypes = method.getParameterTypes();
 			this.varArgCls = paramTypes[paramTypes.length-1].getComponentType();
@@ -26,7 +28,7 @@
 	}
 
 	@Override public Object[] call(LuaState lua,Object... args) {
-		args = fixArgs(args);
+		args = fixArgs(lua,args);
 		Object rtn;
 		try {
 			rtn = method.invoke(obj,args);
@@ -38,37 +40,39 @@
 		return rtnConverter.convert(rtn);
 	}
 
-	private Object[] fixArgs(Object[] args) {
-		if( varArgCls==null ) {
-			if( args.length != argConverters.length ) {
-				Object[] t = new Object[argConverters.length];
-				System.arraycopy(args,0,t,0,Math.min(args.length,t.length));
-				args = t;
-			}
-			for( int i=0; i<args.length; i++ ) {
-				args[i] = argConverters[i].convert(args[i]);
+	private Object[] fixArgs(LuaState lua,Object[] args) {
+		int n = argConverters.length;
+		Object[] rtn;
+		int start = 0;
+		if( !takesLuaState && varArgCls==null && args.length == n ) {
+			rtn = args;
+		} else {
+			if( takesLuaState )
+				n++;
+			rtn = new Object[n];
+			if( takesLuaState ) {
+				rtn[start++] = lua;
 			}
-			return args;
-		} else {  // varargs
-			Object[] rtn = new Object[argConverters.length];
-			int n = argConverters.length - 1;
-			if( args.length < argConverters.length ) {
-				System.arraycopy(args,0,rtn,0,args.length);
-				rtn[rtn.length-1] = Array.newInstance(varArgCls,0);
-			} else {
-				System.arraycopy(args,0,rtn,0,n);
-				Object[] varArgs = (Object[])Array.newInstance(varArgCls,args.length-n);
-				ArgConverter ac = argConverters[n];
-				for( int i=0; i<varArgs.length; i++ ) {
-					varArgs[i] = ac.convert(args[n+i]);
+			n = argConverters.length;
+			if( varArgCls != null ) {
+				n--;
+				if( args.length < argConverters.length ) {
+					rtn[rtn.length-1] = Array.newInstance(varArgCls,0);
+				} else {
+					Object[] varArgs = (Object[])Array.newInstance(varArgCls,args.length-n);
+					ArgConverter ac = argConverters[n];
+					for( int i=0; i<varArgs.length; i++ ) {
+						varArgs[i] = ac.convert(args[n+i]);
+					}
+					rtn[rtn.length-1] = varArgs;
 				}
-				rtn[rtn.length-1] = varArgs;
 			}
-			for( int i=0; i<n; i++ ) {
-				rtn[i] = argConverters[i].convert(rtn[i]);
-			}
-			return rtn;
+			System.arraycopy(args,0,rtn,start,Math.min(args.length,n));
 		}
+		for( int i=0; i<n; i++ ) {
+			rtn[start+i] = argConverters[i].convert(rtn[start+i]);
+		}
+		return rtn;
 	}
 
 
@@ -218,9 +222,19 @@
 		}
 	};
 
-	private static ArgConverter[] getArgConverters(Method m) {
+	private static boolean takesLuaState(Method m) {
+		Class<?>[] paramTypes = m.getParameterTypes();
+		return paramTypes.length > 0 && paramTypes[0].equals(LuaState.class);
+	}
+
+	private static ArgConverter[] getArgConverters(boolean takesLuaState,Method m) {
 		final boolean isVarArgs = m.isVarArgs();
 		Class<?>[] paramTypes = m.getParameterTypes();
+		if( takesLuaState ) {
+			Class<?>[] t = new Class<?>[paramTypes.length-1];
+			System.arraycopy(paramTypes,1,t,0,t.length);
+			paramTypes = t;
+		}
 		ArgConverter[] a = new ArgConverter[paramTypes.length];
 		for( int i=0; i<a.length; i++ ) {
 			Class<?> paramType = paramTypes[i];
diff -r d85510d92eee -r c93d8c781853 src/luan/LuaState.java
--- a/src/luan/LuaState.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/LuaState.java	Tue Dec 04 09:16:03 2012 +0000
@@ -20,12 +20,15 @@
 	}
 
 	private LuaStack stack = null;
+	public Object[] returnValues;
 
-	public void newStack(int stackSize) {
+	Object[] newStack(int stackSize) {
+		returnValues = LuaFunction.EMPTY_RTN;
 		stack = new LuaStack(stack,stackSize);
+		return stack.a;
 	}
 
-	public void popStack() {
+	void popStack() {
 		stack = stack.previousStack;
 	}
 
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/Block.java
--- a/src/luan/interp/Block.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/interp/Block.java	Tue Dec 04 09:16:03 2012 +0000
@@ -5,7 +5,7 @@
 
 
 final class Block implements Stmt {
-	private final Stmt[] stmts;
+	final Stmt[] stmts;
 	private final int stackStart;
 	private final int stackEnd;
 
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/Chunk.java
--- a/src/luan/interp/Chunk.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/interp/Chunk.java	Tue Dec 04 09:16:03 2012 +0000
@@ -2,24 +2,37 @@
 
 import luan.LuaState;
 import luan.LuaException;
+import luan.LuaClosure;
 
 
-final class Chunk implements Stmt {
-	private final Stmt block;
-	private final int stackSize;
+public final class Chunk implements Expr {
+	public final Stmt block;
+	public final int stackSize;
+	public final int numArgs;
 
-	Chunk(Stmt block,int stackSize) {
+	Chunk(Stmt block,int stackSize,int numArgs) {
 		this.block = block;
 		this.stackSize = stackSize;
-	}
-
-	@Override public void eval(LuaState lua) throws LuaException {
-		lua.newStack(stackSize);
-		try {
-			block.eval(lua);
-		} finally {
-			lua.popStack();
+		this.numArgs = numArgs;
+		Stmt stmt = block;
+		while( stmt instanceof Block ) {
+			Block b = (Block)stmt;
+			if( b.stmts.length==0 )
+				break;
+			stmt = b.stmts[b.stmts.length-1];
+		}
+		if( stmt instanceof ReturnStmt ) {
+			ReturnStmt rs = (ReturnStmt)stmt;
+			rs.throwReturnException = false;
 		}
 	}
 
+	public LuaClosure newClosure(LuaState lua) {
+		return new LuaClosure(this,lua);
+	}
+
+	@Override public Object eval(LuaState lua) {
+		return newClosure(lua);
+	}
+
 }
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/LuaParser.java
--- a/src/luan/interp/LuaParser.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/interp/LuaParser.java	Tue Dec 04 09:16:03 2012 +0000
@@ -22,11 +22,21 @@
 
 
 public class LuaParser extends BaseParser<Object> {
+
+	static final class Frame {
+		final Frame parent;
+		final List<String> symbols = new ArrayList<String>();
+		int stackSize = 0;
+		int loops = 0;
+
+		Frame(Frame parent) {
+			this.parent = parent;
+		}
+	}
+
 	int nEquals;
 	int parens = 0;
-	List<String> symbols = new ArrayList<String>();
-	int stackSize = 0;
-	int loops = 0;
+	Frame frame = new Frame(null);
 
 	boolean nEquals(int n) {
 		nEquals = n;
@@ -43,7 +53,23 @@
 		return true;
 	}
 
+	List<String> symbols() {
+		return frame.symbols;
+	}
+
+	int symbolsSize() {
+		return frame.symbols.size();
+	}
+
+	boolean addSymbol(String name) {
+		frame.symbols.add(name);
+		if( frame.stackSize < symbolsSize() )
+			frame.stackSize = symbolsSize();
+		return true;
+	}
+
 	int index(String name) {
+		List<String> symbols = frame.symbols;
 		int i = symbols.size();
 		while( --i >= 0 ) {
 			if( symbols.get(i).equals(name) )
@@ -53,6 +79,7 @@
 	}
 
 	boolean popSymbols(int n) {
+		List<String> symbols = frame.symbols;
 		while( n-- > 0 ) {
 			symbols.remove(symbols.size()-1);
 		}
@@ -60,12 +87,12 @@
 	}
 
 	boolean incLoops() {
-		loops++;
+		frame.loops++;
 		return true;
 	}
 
 	boolean decLoops() {
-		loops--;
+		frame.loops--;
 		return true;
 	}
 
@@ -75,21 +102,18 @@
 			Spaces(),
 			FirstOf(
 				Sequence( ExpList(), EOI ),
-				Sequence( Chunk(), EOI )
+				Sequence(
+					Block(),
+					EOI,
+					push( new Chunk( (Stmt)pop(), frame.stackSize, 0 ) )
+				)
 			)
 		);
 	}
 
-	Rule Chunk() {
-		return Sequence(
-			Block(),
-			push( new Chunk( (Stmt)pop(), stackSize ) )
-		);
-	}
-
 	Rule Block() {
 		Var<List<Stmt>> stmts = new Var<List<Stmt>>(new ArrayList<Stmt>());
-		Var<Integer> stackStart = new Var<Integer>(symbols.size());
+		Var<Integer> stackStart = new Var<Integer>(symbolsSize());
 		return Sequence(
 			Optional( Stmt(stmts) ),
 			ZeroOrMore(
@@ -101,9 +125,7 @@
 	}
 
 	Stmt newBlock(List<Stmt> stmts,int stackStart) {
-		if( stackSize < symbols.size() )
-			stackSize = symbols.size();
-		int stackEnd = symbols.size();
+		int stackEnd = symbolsSize();
 		popSymbols( stackEnd - stackStart );
 		if( stmts.isEmpty() )
 			return Stmt.EMPTY;
@@ -134,6 +156,9 @@
 			LocalStmt(stmts),
 			Sequence(
 				FirstOf(
+					ReturnStmt(),
+					FunctionStmt(),
+					LocalFunctionStmt(),
 					BreakStmt(),
 					GenericForStmt(),
 					NumericForStmt(),
@@ -149,20 +174,54 @@
 		);
 	}
 
+	Rule ReturnStmt() {
+		return Sequence(
+			Keyword("return"), Expressions(),
+			push( new ReturnStmt( (Expressions)pop() ) )
+		);
+	}
+
+	Rule FunctionStmt() {
+		return Sequence(
+			Keyword("function"), FnName(), Function(),
+			push( new SetStmt( (Settable)pop(1), expr(pop()) ) )
+		);
+	}
+
+	Rule FnName() {
+		return Sequence(
+			push(null),  // marker
+			Name(),
+			ZeroOrMore(
+				'.', Spaces(),
+				makeVarExp(),
+				NameExpr()
+			),
+			makeSettableVar()
+		);
+	}
+
+	Rule LocalFunctionStmt() {
+		return Sequence(
+			Keyword("local"), Keyword("function"), LocalName(), Function(),
+			push( new SetStmt( new SetLocalVar(symbolsSize()-1), expr(pop()) ) )
+		);
+	}
+
 	Rule BreakStmt() {
 		return Sequence(
 			Keyword("break"),
-			loops > 0,
+			frame.loops > 0,
 			push( new BreakStmt() )
 		);
 	}
 
 	Rule GenericForStmt() {
-		Var<Integer> stackStart = new Var<Integer>(symbols.size());
+		Var<Integer> stackStart = new Var<Integer>(symbolsSize());
 		return Sequence(
 			Keyword("for"), NameList(), Keyword("in"), Expr(), Keyword("do"), LoopBlock(), Keyword("end"),
-			push( new GenericForStmt( stackStart.get(), symbols.size() - stackStart.get(), expr(pop(1)), (Stmt)pop() ) ),
-			popSymbols( symbols.size() - stackStart.get() )
+			push( new GenericForStmt( stackStart.get(), symbolsSize() - stackStart.get(), expr(pop(1)), (Stmt)pop() ) ),
+			popSymbols( symbolsSize() - stackStart.get() )
 		);
 	}
 
@@ -175,9 +234,9 @@
 				drop(),
 				Expr()
 			),
-			symbols.add( (String)pop(3) ),  // add "for" var to symbols
+			addSymbol( (String)pop(3) ),  // add "for" var to symbols
 			Keyword("do"), LoopBlock(), Keyword("end"),
-			push( new NumericForStmt( symbols.size()-1, expr(pop(3)), expr(pop(2)), expr(pop(1)), (Stmt)pop() ) ),
+			push( new NumericForStmt( symbolsSize()-1, expr(pop(3)), expr(pop(2)), expr(pop(1)), (Stmt)pop() ) ),
 			popSymbols(1)
 		);
 	}
@@ -189,7 +248,7 @@
 	}
 
 	Rule LocalStmt(Var<List<Stmt>> stmts) {
-		Var<Integer> stackStart = new Var<Integer>(symbols.size());
+		Var<Integer> stackStart = new Var<Integer>(symbolsSize());
 		return Sequence(
 			Keyword("local"), NameList(),
 			Optional(
@@ -201,18 +260,23 @@
 
 	Rule NameList() {
 		return Sequence(
-			Name(),
-			symbols.add( (String)pop() ),
+			LocalName(),
 			ZeroOrMore(
-				',', Spaces(), Name(),
-				symbols.add( (String)pop() )
+				',', Spaces(), LocalName()
 			)
 		);
 	}
 
+	Rule LocalName() {
+		return Sequence(
+			Name(),
+			addSymbol( (String)pop() )
+		);
+	}
+
 	SetStmt newSetLocalStmt(int stackStart) {
 		Expressions values = (Expressions)pop();
-		SetLocalVar[] vars = new SetLocalVar[symbols.size()-stackStart];
+		SetLocalVar[] vars = new SetLocalVar[symbolsSize()-stackStart];
 		for( int i=0; i<vars.length; i++ ) {
 			vars[i] = new SetLocalVar(stackStart+i);
 		}
@@ -289,36 +353,39 @@
 	}
 
 	Rule VarList() {
+		Var<List<Settable>> vars = new Var<List<Settable>>(new ArrayList<Settable>());
 		return Sequence(
-			push(new ArrayList<Settable>()),
-			Var(),
-			addToVarList(),
+			SettableVar(),
+			vars.get().add( (Settable)pop() ),
 			ZeroOrMore(
-				',', Spaces(), Var(),
-				addToVarList()
-			)
+				',', Spaces(), SettableVar(),
+				vars.get().add( (Settable)pop() )
+			),
+			push(vars.get())
 		);
 	}
 
-	boolean addToVarList() {
+	Rule SettableVar() {
+		return Sequence( Var(), makeSettableVar() );
+	}
+
+	boolean makeSettableVar() {
 		Object obj2 = pop();
 		if( obj2==null )
 			return false;
 		Object obj1 = pop();
-		@SuppressWarnings("unchecked")
-		List<Settable> vars = (List<Settable>)peek();
 		if( obj1==null ) {
 			String name = (String)obj2;
 			int index = index(name);
 			if( index == -1 ) {
-				vars.add( new SetTableEntry( EnvExpr.INSTANCE, new ConstExpr(name) ) );
+				push( new SetTableEntry( EnvExpr.INSTANCE, new ConstExpr(name) ) );
 			} else {
-				vars.add( new SetLocalVar(index) );
+				push( new SetLocalVar(index) );
 			}
 		} else {
 			Expr key = expr(obj2);
 			Expr table = expr(obj1);
-			vars.add( new SetTableEntry(table,key) );
+			push( new SetTableEntry(table,key) );
 		}
 		return true;
 	}
@@ -407,12 +474,26 @@
 
 	Rule SingleExpr() {
 		return FirstOf(
+			FunctionExpr(),
 			TableExpr(),
 			VarExp(),
 			LiteralExpr()
 		);
 	}
 
+	Rule FunctionExpr() {
+		return Sequence( "function", Spaces(), Function() );
+	}
+
+	Rule Function() {
+		return Sequence(
+			action( frame = new Frame(frame) ),
+			'(', incParens(), Spaces(), Optional(NameList()), ')', decParens(), Spaces(), Block(), Keyword("end"),
+			push( new Chunk( (Stmt)pop(), frame.stackSize, symbolsSize() ) ),
+			action( frame = frame.parent )
+		);
+	}
+
 	Rule TableExpr() {
 		return Sequence(
 			'{', incParens(), Spaces(),
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/ReturnException.java
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/luan/interp/ReturnException.java	Tue Dec 04 09:16:03 2012 +0000
@@ -0,0 +1,4 @@
+package luan.interp;
+
+
+public final class ReturnException extends RuntimeException {}
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/ReturnStmt.java
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/luan/interp/ReturnStmt.java	Tue Dec 04 09:16:03 2012 +0000
@@ -0,0 +1,20 @@
+package luan.interp;
+
+import luan.LuaState;
+import luan.LuaException;
+
+
+final class ReturnStmt implements Stmt {
+	private final Expressions expressions;
+	boolean throwReturnException = true;
+
+	ReturnStmt(Expressions expressions) {
+		this.expressions = expressions;
+	}
+
+	@Override public void eval(LuaState lua) throws LuaException {
+		lua.returnValues = expressions.eval(lua);
+		if( throwReturnException )
+			throw new ReturnException();
+	}
+}
diff -r d85510d92eee -r c93d8c781853 src/luan/interp/SetStmt.java
--- a/src/luan/interp/SetStmt.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/interp/SetStmt.java	Tue Dec 04 09:16:03 2012 +0000
@@ -9,6 +9,10 @@
 	private final Settable[] vars;
 	private final Expressions expressions;
 
+	SetStmt(Settable var,Expr expr) {
+		this( new Settable[]{var}, new ExpList.SingleExpList(expr) );
+	}
+
 	SetStmt(Settable[] vars,Expressions expressions) {
 		this.vars = vars;
 		this.expressions = expressions;
diff -r d85510d92eee -r c93d8c781853 src/luan/lib/BasicLib.java
--- a/src/luan/lib/BasicLib.java	Sun Dec 02 10:51:18 2012 +0000
+++ b/src/luan/lib/BasicLib.java	Tue Dec 04 09:16:03 2012 +0000
@@ -19,7 +19,7 @@
 import luan.LuaException;
 import luan.interp.LuaParser;
 import luan.interp.Expressions;
-import luan.interp.Stmt;
+import luan.interp.Chunk;
 
 
 public class BasicLib {
@@ -28,8 +28,8 @@
 		LuaTable t = lua.env();
 		add( t, "print", new Object[0].getClass() );
 		add( t, "type", Object.class );
-		add( t, "load", String.class );
-		add( t, "loadFile", String.class );
+		add( t, "load", LuaState.class, String.class );
+		add( t, "loadFile", LuaState.class, String.class );
 		add( t, "pairs", LuaTable.class );
 		add( t, "ipairs", LuaTable.class );
 	}
@@ -55,7 +55,7 @@
 		return Lua.type(obj);
 	}
 
-	public static LuaFunction load(String ld) throws LuaException {
+	public static LuaFunction load(LuaState lua,String ld) throws LuaException {
 		LuaParser parser = Parboiled.createParser(LuaParser.class);
 		ParsingResult<?> result = new ReportingParseRunner(parser.Target()).run(ld);
 //		ParsingResult<?> result = new TracingParseRunner(parser.Target()).run(ld);
@@ -70,13 +70,8 @@
 				}
 			};
 		}
-		final Stmt stmt = (Stmt)resultValue;
-		return new LuaFunction() {
-			public Object[] call(LuaState lua,Object... args) throws LuaException {
-				stmt.eval(lua);
-				return LuaFunction.EMPTY_RTN;
-			}
-		};
+		Chunk chunk = (Chunk)resultValue;
+		return chunk.newClosure(lua);
 	}
 
 	public static String readFile(String fileName) throws IOException {
@@ -90,8 +85,8 @@
 		return sb.toString();
 	}
 
-	public static LuaFunction loadFile(String fileName) throws LuaException,IOException {
-		return load(readFile(fileName));
+	public static LuaFunction loadFile(LuaState lua,String fileName) throws LuaException,IOException {
+		return load(lua,readFile(fileName));
 	}
 
 	private static class TableIter {