// Copyright (c) 2001 Hursh Jain (http://www.mollypages.org) 
// The Molly framework is freely distributable under the terms of an
// MIT-style license. For details, see the molly pages web site at:
// http://www.mollypages.org/. Use, modify, have fun !

package fc.jdbc;

import java.io.*;
import java.util.*;
import java.math.*;
import java.util.regex.*;
import java.sql.*;
import javax.servlet.*;

import fc.io.*;
import fc.web.*;
import fc.util.*;

/** 
This class wraps around {@link PreparedStatement} and allows the programmer to set
parameters by name instead of by question mark index.
<p>
Inspired by a similar concept at: <a href="http://www.javaworld.com/javaworld/jw-04-2007/jw-04-jdbc.html">javaworld</a> 
(although this class was rewritten from scratch).
<p>
Named parameters are written as <i>@foo</i> (that is, they start with a <i>@</i> character). A parameter name may contain only ASCII <i>letters</i>, <i>digits</i>, <i>underscores</i> and <i>dots</i>. The first character outside this set ends the name, and normal SQL continues from there. A dash therefore ends a name: <tt>@a-1</tt> is the parameter <tt>a</tt> followed by the SQL text <tt>-1</tt>. The query is not tokenized, so a <i>@</i> inside a quoted SQL string literal is also treated as a parameter marker. Conversely, a parameter written inside quotes turns into a quoted <tt>'?'</tt>, which JDBC drivers generally do not treat as a bind position. Named parameters should therefore never be quoted. For example:
<blockquote><pre>
select * from foo where search = @search_term and radius = @radius::int
</pre></blockquote>
This contains two named parameters, <font color=blue>@search_term</font> and <font color=blue>@radius</font>. The <tt>::int</tt> after <tt>@radius</tt> is not part of the name; here it is a PostgreSQL-style cast. To use this in code, we say:
<blockquote><pre>

Connection con = getConnection(); //some way to get a connection
String query = 
  "select * from foo where search = <font color=blue>@search_term</font> and radius = <font color=blue>@radius</font>::int";

NamedParamStatement ps = NamedParamStatement.get(con, query);
ps.setString("<b>search_term</b>", "hello world");
ps.setInt("<b>radius</b>", 42);

ResultSet rs = ps.executeQuery();
</pre></blockquote>
<p>
<b>Note</b>: When setting a named parameter, the "@" must be omitted.
<p>
The same named parameter can appear multiple times in the query, and is replaced
wherever it appears by its value.
<p>
The {@link #close} method should be called to release resources and help with
garbage collection (right around the time close is called on the 
associated connection, which is after any retrieved data/resultset has
been fully read).

@author hursh jain
**/
public final class NamedParamStatement
{
private static final boolean dbg = false;

/*
The parse result for one original query. For example, an original query of:
	select * from foo where bar = @bar or baz = @bar
becomes
	parsedQuery = select * from foo where bar = ?  or baz = ?
	indexMap    = {bar=[1, 2]}
Fields are final and the map is unmodifiable (the lists inside are never
modified after parse() returns), so one instance can be shared by every
statement and thread that uses the same query string.
*/
private static final class ParseData
	{
	final String parsedQuery;
	//parameter name (without the "@") -> 1-based '?' positions in parsedQuery
	final Map<String, List<Integer>> indexMap;

	ParseData(String parsedQuery, Map<String, List<Integer>> indexMap)
		{
		this.parsedQuery = parsedQuery;
		this.indexMap = indexMap;
		}

	@Override
	public String toString() {
		return "parsedQuery:" + parsedQuery + "; indexMap:" + indexMap;
		}
	}

/*
{original query : ParseData}, shared by all threads. This map must be
synchronized: concurrent put()s into a plain HashMap can corrupt it, and an
unsynchronized get() is not guaranteed to see a fully constructed value.
NOTE: entries are never removed, so every distinct query string stays cached
for the lifetime of the class. Avoid building query strings that inline
literal values.
*/
private static final Map<String, ParseData> queryMap =
	Collections.synchronizedMap(new HashMap<String, ParseData>());

//the wrapped statement; null once close() has been called
private PreparedStatement wrappedPS;
//convenience ref to the shared, read-only ParseData.indexMap for this query
private final Map<String, List<Integer>> indexMap;

private NamedParamStatement(PreparedStatement wrappedPS, Map<String, List<Integer>> indexMap)
	{
	this.wrappedPS = wrappedPS;
	this.indexMap = indexMap;
	}

/**
Returns a new NamedParamStatement that wraps a new {@link PreparedStatement}
created on <tt>con</tt>. The query string is parsed the first time it is seen,
and later calls with the same string reuse the cached parse. ResultSets
produced by the returned statement are not scrollable (use
{@link #getScrollable} for a scrollable result set).

@param	con		the connection on which to prepare the statement
@param	query	SQL containing <tt>@name</tt> style parameters
@throws	IllegalArgumentException	if <tt>query</tt> is null
@throws	SQLException	if the driver cannot prepare the statement
*/
public static NamedParamStatement get(Connection con, String query) throws SQLException
	{
	return add(con, query, false);
	}

/**
Same as {@link #get}, except that any {@link ResultSet} returned by the
wrapped PreparedStatement is scrollable. The statement is created with
<tt>ResultSet.TYPE_SCROLL_INSENSITIVE</tt> and
<tt>ResultSet.CONCUR_READ_ONLY</tt>.

@param	con		the connection on which to prepare the statement
@param	query	SQL containing <tt>@name</tt> style parameters
@throws	IllegalArgumentException	if <tt>query</tt> is null
@throws	SQLException	if the driver cannot prepare the statement
*/
public static NamedParamStatement getScrollable(Connection con, String query) throws SQLException
	{
	return add(con, query, true);
	}

//Group 2 is the parameter name without the "@"; that is the key stored in
//the index map. Group 1 is the whitespace before the "@", which is copied
//unchanged into the parsed query.
static final Pattern pat = Pattern.compile("(\\s*)@([a-zA-Z_0-9.]+)");

/*
Replaces every @name in query with '?' and records, for each name, the
1-based positions of the '?'s that replaced it. Does not touch the cache.
*/
private static ParseData parse(String query)
	{
	if (dbg) System.out.println("Analyzing query: \n" + query);

	Map<String, List<Integer>> indexMap = new HashMap<String, List<Integer>>();

	Matcher match = pat.matcher(query);
	//Matcher.appendReplacement only accepts a StringBuffer before Java 9
	StringBuffer sb = new StringBuffer(query.length());

	int pos = 1;  //JDBC '?' positions start from 1
	while (match.find())
		{
		match.appendReplacement(sb, "$1?");
		String paramName = match.group(2);
		if (dbg) System.out.println("Found replacement name: @" + paramName);

		List<Integer> indexes = indexMap.get(paramName);
		if (indexes == null) {
			indexes = new ArrayList<Integer>();
			indexMap.put(paramName, indexes);
			}
		indexes.add(pos++);
		}
	match.appendTail(sb);

	String parsedQuery = sb.toString();

	if (dbg) System.out.println("Replacement index map: \n" + indexMap);
	if (dbg) System.out.println("Replaced query: " + parsedQuery);

	return new ParseData(parsedQuery, Collections.unmodifiableMap(indexMap));
	}

/*
Looks up the parse for query (parsing and caching it on a miss), then
prepares a new PreparedStatement for it on con.
*/
private static NamedParamStatement add(Connection con, String query, boolean scrollable)
throws SQLException
	{
	if (query == null) {
		throw new IllegalArgumentException("'query' parameter was null");
		}

	//Check-then-put is deliberately not atomic. Two threads that miss at the
	//same time both parse the query, which only wastes a little work because
	//the results are equivalent. The map itself is synchronized, so it stays
	//consistent.
	ParseData pd = queryMap.get(query);
	if (pd == null) {
		pd = parse(query);
		queryMap.put(query, pd);
		}

	//The PreparedStatement belongs to 'con', so it cannot be cached across
	//connections. Only the (expensive) parse above is shared.
	final PreparedStatement ps;
	if (scrollable) {
		ps = con.prepareStatement(pd.parsedQuery,
								ResultSet.TYPE_SCROLL_INSENSITIVE,
								ResultSet.CONCUR_READ_ONLY);
		}
	else {
		ps = con.prepareStatement(pd.parsedQuery);
		}

	return new NamedParamStatement(ps, pd.indexMap);
	}

/*
Returns the 1-based '?' positions for name.
@throws IllegalArgumentException if name does not occur in the query
*/
private List<Integer> getIndexes(String name)
	{
	List<Integer> indexes = indexMap.get(name);

	if (indexes == null) {
		//string concatenation so that a closed (null) statement cannot cause an NPE here
		throw new IllegalArgumentException("NamedParamStatement [" + wrappedPS + "], replacement parameter not found, parameter name=" + name);
		}

	return indexes;
	}

/*
Returns the wrapped statement, or throws SQLException if this statement has
been closed. Without this check, callers would get a NullPointerException.
*/
private PreparedStatement ps() throws SQLException
	{
	if (wrappedPS == null) {
		throw new SQLException("NamedParamStatement has already been closed");
		}
	return wrappedPS;
	}

/**
Closes the wrapped PreparedStatement. Calling this method on an already
closed statement has no effect, matching {@link java.sql.Statement#close}.
After closing, methods that delegate to the wrapped statement throw an
SQLException.
*/
public void close() throws SQLException
	{
	if (wrappedPS == null) {
		return;
		}
	try {
		wrappedPS.close();
		}
	finally {
		wrappedPS = null;
		}
	}

/** Returns the wrapped statement's toString(), or a fixed marker once closed. */
@Override
public String toString()
	{
	return (wrappedPS == null) ? "NamedParamStatement[closed]" : wrappedPS.toString();
	}

//========================== wrapped methods ===========================
/*
Each setXXX(String name, ...) method below calls the same-named
PreparedStatement setter once for EVERY '?' position that replaced @name in
the query, so a repeated name receives the same value everywhere. The setters
throw IllegalArgumentException if name does not occur in the query, and
SQLException if this statement has been closed.
*/

/** See {@link PreparedStatement#executeQuery}. */
public ResultSet executeQuery() throws SQLException
	{
	return ps().executeQuery();
	}

/** See {@link PreparedStatement#executeUpdate}. */
public int executeUpdate() throws SQLException
	{
	return ps().executeUpdate();
	}

/** Sets every occurrence of the named parameter to SQL NULL. */
public void setNull(String name, int sqlType) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setNull(idx, sqlType);
		}
	}

/** Sets every occurrence of the named parameter to a boolean. */
public void setBoolean(String name, boolean x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setBoolean(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a byte. */
public void setByte(String name, byte x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setByte(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a short. */
public void setShort(String name, short x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setShort(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to an int. */
public void setInt(String name, int x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setInt(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a long. */
public void setLong(String name, long x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setLong(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a float. */
public void setFloat(String name, float x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setFloat(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a double. */
public void setDouble(String name, double x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setDouble(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a BigDecimal. */
public void setBigDecimal(String name, BigDecimal x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setBigDecimal(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a String. */
public void setString(String name, String x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setString(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a byte array. */
public void setBytes(String name, byte[] x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setBytes(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a java.sql.Date. */
public void setDate(String name, java.sql.Date x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setDate(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a java.sql.Time. */
public void setTime(String name, java.sql.Time x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setTime(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a java.sql.Timestamp. */
public void setTimestamp(String name, java.sql.Timestamp x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setTimestamp(idx, x);
		}
	}

/**
Binds an ASCII stream to every occurrence of the named parameter.
NOTE(review): the same stream object is bound to each position. A stream can
normally be read only once, so stream parameters should appear only once in
the query.
*/
public void setAsciiStream(String name, java.io.InputStream x, int length) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setAsciiStream(idx, x, length);
		}
	}

/**
Binds a binary stream to every occurrence of the named parameter.
NOTE(review): the same stream object is bound to each position. A stream can
normally be read only once, so stream parameters should appear only once in
the query.
*/
public void setBinaryStream(String name, java.io.InputStream x, int length) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setBinaryStream(idx, x, length);
		}
	}

/** See {@link PreparedStatement#clearParameters}. */
public void clearParameters() throws SQLException
	{
	ps().clearParameters();
	}

/** Sets every occurrence of the named parameter to an object of the given SQL type and scale. */
public void setObject(String name, Object x, int targetSqlType, int scale) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setObject(idx, x, targetSqlType, scale);
		}
	}

/** Sets every occurrence of the named parameter to an object of the given SQL type. */
public void setObject(String name, Object x, int targetSqlType) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setObject(idx, x, targetSqlType);
		}
	}

/** Sets every occurrence of the named parameter to an object. */
public void setObject(String name, Object x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setObject(idx, x);
		}
	}

/** See {@link PreparedStatement#execute}. */
public boolean execute() throws SQLException
	{
	return ps().execute();
	}

/** See {@link PreparedStatement#addBatch}. */
public void addBatch() throws SQLException
	{
	ps().addBatch();
	}

/**
Binds a character stream to every occurrence of the named parameter.
NOTE(review): the same Reader is bound to each position. A Reader can
normally be read only once, so such parameters should appear only once in
the query.
*/
public void setCharacterStream(String name, java.io.Reader reader, int length) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setCharacterStream(idx, reader, length);
		}
	}

/** Sets every occurrence of the named parameter to a Ref. */
public void setRef(String name, Ref x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setRef(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a Blob. */
public void setBlob(String name, Blob x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setBlob(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a Clob. */
public void setClob(String name, Clob x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setClob(idx, x);
		}
	}

/** Sets every occurrence of the named parameter to a java.sql.Array. */
public void setArray(String name, Array x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setArray(idx, x);
		}
	}

/** See {@link PreparedStatement#getMetaData}. */
public ResultSetMetaData getMetaData() throws SQLException
	{
	return ps().getMetaData();
	}

/** Sets every occurrence of the named parameter to a Date, using the given calendar. */
public void setDate(String name, java.sql.Date x, Calendar cal) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setDate(idx, x, cal);
		}
	}

/** Sets every occurrence of the named parameter to a Time, using the given calendar. */
public void setTime(String name, java.sql.Time x, Calendar cal) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setTime(idx, x, cal);
		}
	}

/** Sets every occurrence of the named parameter to a Timestamp, using the given calendar. */
public void setTimestamp(String name, java.sql.Timestamp x, Calendar cal) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setTimestamp(idx, x, cal);
		}
	}

/** Sets every occurrence of the named parameter to SQL NULL of a user-defined/REF type. */
public void setNull(String name, int sqlType, String typeName) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setNull(idx, sqlType, typeName);
		}
	}

/** Sets every occurrence of the named parameter to a URL. */
public void setURL(String name, java.net.URL x) throws SQLException
	{
	for (int idx : getIndexes(name)) {
		ps().setURL(idx, x);
		}
	}

/** See {@link PreparedStatement#getParameterMetaData}. */
public ParameterMetaData getParameterMetaData() throws SQLException
	{
	return ps().getParameterMetaData();
	}

}