下面是实现了 MyBatis Interceptor 接口的分页插件类 PageInterceptor：
package dwz.common.mybatis; import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.List; import java.util.Properties; import javax.xml.bind.PropertyException; import org.apache.commons.lang3.StringUtils; import org.apache.ibatis.binding.MapperMethod; import org.apache.ibatis.executor.parameter.ParameterHandler; import org.apache.ibatis.executor.statement.RoutingStatementHandler; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ParameterMapping; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; import dwz.common.mybatis.Page; import dwz.common.util.ReflectUtil; @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) }) @SuppressWarnings("rawtypes") public class PageInterceptor implements Interceptor { private static String databaseType ="";// 数据库类型,不同的数据库有不同的分页方法 /** * 拦截后要执行的方法 */ public Object intercept(Invocation invocation) throws Throwable { RoutingStatementHandler handler = (RoutingStatementHandler) invocation .getTarget(); StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate"); BoundSql boundSql = delegate.getBoundSql(); Object params = boundSql.getParameterObject(); Page page = null; if (params instanceof Page) { page = (Page) params; } else if (params instanceof MapperMethod.ParamMap) { MapperMethod.ParamMap paramMap = (MapperMethod.ParamMap) params; for (Object key : paramMap.keySet()) { if (paramMap.get(key) instanceof Page) { page = (Page) paramMap.get(key); break; } } } if (page != null) { MappedStatement 
mappedStatement = (MappedStatement) ReflectUtil .getFieldValue(delegate, "mappedStatement"); Connection connection = (Connection) invocation.getArgs()[0]; String sql = boundSql.getSql(); this.setTotalRecord(page, (MapperMethod.ParamMap) params, mappedStatement, connection); String pageSql = this.getPageSql(page, sql); ReflectUtil.setFieldValue(boundSql, "sql", pageSql); } return invocation.proceed(); } /** * 拦截器对应的封装原始对象的方法 */ public Object plugin(Object target) { return Plugin.wrap(target, this); } public void setProperties(Properties p) { databaseType = p.getProperty("databaseType"); if (StringUtils.isEmpty(databaseType)) { try { throw new PropertyException("databaseType is not found!"); } catch (PropertyException e) { e.printStackTrace(); } } } private String getPageSql(Page<?> page, String sql) { StringBuffer sqlBuffer = new StringBuffer(sql); if ("mysql".equalsIgnoreCase(databaseType)) { return getMysqlPageSql(page, sqlBuffer); } else if ("oracle".equalsIgnoreCase(databaseType)) { return getOraclePageSql(page, sqlBuffer); } else if ("sqlserver".equalsIgnoreCase(databaseType)) { return getSqlserverPageSql(page, sqlBuffer); } return sqlBuffer.toString(); } private String getSqlserverPageSql(Page<?> page, StringBuffer sqlBuffer) { // 计算第一条记录的位置,Sqlserver中记录的位置是从0开始的。 int startRowNum = (page.getPageNum() - 1) * page.getPageSize() + 1; int endRowNum = startRowNum + page.getPageSize(); String sql = "select appendRowNum.row,* from (select ROW_NUMBER() OVER (order by (select 0)) AS row,* from (" + sqlBuffer.toString() + ") as innerTable" + ")as appendRowNum where appendRowNum.row >= " + startRowNum + " AND appendRowNum.row <= " + endRowNum; return sql; } private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) { // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。 int offset = (page.getPageNum() - 1) * page.getPageSize(); sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize()); return sqlBuffer.toString(); } private String getOraclePageSql(Page<?> 
page, StringBuffer sqlBuffer) { // 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的 int offset = (page.getPageNum() - 1) * page.getPageSize() + 1; sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ") .append(offset + page.getPageSize()); sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset); return sqlBuffer.toString(); } /** * 给当前的参数对象page设置总记录数 * * @param page * Mapper映射语句对应的参数对象 * @param mappedStatement * Mapper映射语句 * @param connection */ private void setTotalRecord(Page<?> page, MapperMethod.ParamMap params, MappedStatement mappedStatement, Connection connection) { BoundSql boundSql = mappedStatement.getBoundSql(params); String sql = boundSql.getSql(); String countSql = this.getCountSql(sql); List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,parameterMappings, params); ParameterHandler parameterHandler = new DefaultParameterHandler( mappedStatement, params, countBoundSql); PreparedStatement pstmt = null; ResultSet rs = null; try { pstmt = connection.prepareStatement(countSql); parameterHandler.setParameters(pstmt); rs = pstmt.executeQuery(); if (rs.next()) { int totalRecord = rs.getInt(1); page.setTotalRecord(totalRecord); } } catch (SQLException e) { e.printStackTrace(); } finally { try { if (rs != null)rs.close(); if (pstmt != null)pstmt.close(); } catch (SQLException e) { e.printStackTrace(); } } } /** * 根据原Sql语句获取对应的查询总记录数的Sql语句 * * @param sql * @return */ private String getCountSql(String sql) { return "select count(*) from (" + sql + ") as countRecord"; } }
插件所需的反射工具类 ReflectUtil 以及分页参数类 Page 如下。
package dwz.common.util; import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import org.apache.commons.lang3.StringUtils; import org.apache.commons.beanutils.BeanUtils; import org.springframework.util.Assert; /** * 利用反射进行操作的一个工具类 */ public class ReflectUtil { /** * 利用反射获取指定对象的指定属性 * * @param obj * 目标对象 * @param fieldName * 目标属性 * @return 目标属性的值 */ public static Object getFieldValue(Object obj, String fieldName) { Object result = null; Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { field.setAccessible(true); try { result = field.get(obj); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } return result; } /** * 利用反射获取指定对象里面的指定属性 * * @param obj * 目标对象 * @param fieldName * 目标属性 * @return 目标字段 */ private static Field getField(Object obj, String fieldName) { Field field = null; for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz .getSuperclass()) { try { field = clazz.getDeclaredField(fieldName); break; } catch (NoSuchFieldException e) { // 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。 } } return field; } /** * 利用反射设置指定对象的指定属性为指定的值 * * @param obj * 目标对象 * @param fieldName * 目标属性 * @param fieldValue * 目标值 */ public static void setFieldValue(Object obj, String fieldName, String fieldValue) { Field field = ReflectUtil.getField(obj, fieldName); if (field != null) { try { field.setAccessible(true); field.set(obj, fieldValue); } catch (IllegalArgumentException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } } } /** * 两者属性名一致时,拷贝source里的属性到dest里 * * @return void * @throws IllegalArgumentException * @throws IllegalAccessException * @throws InvocationTargetException */ @SuppressWarnings("unchecked") public static void copyPorperties(Object dest, Object source) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{ Class srcCla = 
source.getClass(); Field[] fsF = srcCla.getDeclaredFields(); for (Field s : fsF) { String name = s.getName(); Object srcObj = invokeGetterMethod(source, name); try { BeanUtils.setProperty(dest, name, srcObj); } catch (Exception e){ e.printStackTrace(); } } } /** * 调用Getter方法. * @throws InvocationTargetException * @throws IllegalArgumentException * @throws IllegalAccessException */ public static Object invokeGetterMethod(Object target, String propertyName) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException { String getterMethodName = "get" + StringUtils.capitalize(propertyName); return invokeMethod(target, getterMethodName, new Class[] {}, new Object[] {}); } /** * 直接调用对象方法, 无视private/protected修饰符. * @throws InvocationTargetException * @throws IllegalArgumentException * @throws IllegalAccessException */ public static Object invokeMethod(final Object object, final String methodName, final Class<?>[] parameterTypes, final Object[] parameters) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{ Method method = getDeclaredMethod(object, methodName, parameterTypes); if (method == null) { throw new IllegalArgumentException("Could not find method [" + methodName + "] parameterType " + parameterTypes + " on target [" + object + "]"); } method.setAccessible(true); return method.invoke(object, parameters); } /** * 循环向上转型, 获取对象的DeclaredMethod. * * 如向上转型到Object仍无法找到, 返回null. */ protected static Method getDeclaredMethod(Object object, String methodName, Class<?>[] parameterTypes) { Assert.notNull(object, "object不能为空"); for (Class<?> superClass = object.getClass(); superClass != Object.class; superClass = superClass .getSuperclass()) { try{ return superClass.getDeclaredMethod(methodName, parameterTypes); } catch (NoSuchMethodException e) {// NOSONAR // Method不在当前类定义,继续向上转型 } } return null; } }
Page 类（可按需重写 toString 方法）：
package dwz.common.mybatis; import java.util.List; /** * 对分页的基本数据进行封装 */ public class Page<T>{ private int pageNum = 1;//页码,默认是第一页 private int pageSize = 5;//每页显示的记录数,默认是5 private int totalRecord;//总记录数 private int total;//总记录数 private int totalPage;//总页数 private List<T> results;//对应的当前页记录 public int getTotal() { return total; } public void setTotal(int total) { this.total = total; } public int getPageNum() { return pageNum; } public void setPageNum(int pageNum) { this.pageNum = pageNum; } public int getPageSize() { return pageSize; } public void setPageSize(int pageSize) { this.pageSize = pageSize; } public int getTotalRecord() { return totalRecord; } public void setTotalRecord(int totalRecord) { this.totalRecord = totalRecord; this.total=totalRecord; int totalPage = totalRecord % pageSize == 0 ? totalRecord / pageSize : totalRecord / pageSize + 1; this.setTotalPage(totalPage); } public int getTotalPage() { return totalPage; } public void setTotalPage(int totalPage) { this.totalPage = totalPage; } public List<T> getResults() { if(null != results && results.size() == 0){ return null; } return results; } public void setResults(List<T> results) { this.results = results; } }
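最后将该插件配置进去，并为 databaseType 赋值（PageInterceptor 支持 mysql、oracle、sqlserver）。注意：下面的片段是 Spring 的 bean 定义，而不是 mybatis-config.xml 的语法，还需要在 SqlSessionFactoryBean 的 plugins 属性中引用它。如果改用 mybatis-config.xml，则应在 &lt;plugins&gt; 元素下写 &lt;plugin interceptor="dwz.common.mybatis.PageInterceptor"&gt;，并加入 &lt;property name="databaseType" value="mysql"/&gt;。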
<!-- Configure the paging plugin: Spring bean definition for dwz.common.mybatis.PageInterceptor. The databaseType property selects the paging dialect (mysql, oracle or sqlserver). NOTE(review): presumably this bean is referenced from the SqlSessionFactoryBean "plugins" property elsewhere; a bean declared on its own is not registered with MyBatis - verify. --> <bean id="pagePlugin" class="dwz.common.mybatis.PageInterceptor"> <property name="properties"> <props> <prop key="databaseType">mysql</prop> </props> </property> </bean>