MyBatis 分页插件 PageInterceptor

匿名 (未验证) 提交于 2019-12-02 22:56:40

下面是实现了 MyBatis Interceptor 接口的分页插件类:

 package dwz.common.mybatis;  import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.util.List; import java.util.Properties;  import javax.xml.bind.PropertyException; import org.apache.commons.lang3.StringUtils; import org.apache.ibatis.binding.MapperMethod; import org.apache.ibatis.executor.parameter.ParameterHandler; import org.apache.ibatis.executor.statement.RoutingStatementHandler; import org.apache.ibatis.executor.statement.StatementHandler; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.mapping.ParameterMapping; import org.apache.ibatis.plugin.Interceptor; import org.apache.ibatis.plugin.Intercepts; import org.apache.ibatis.plugin.Invocation; import org.apache.ibatis.plugin.Plugin; import org.apache.ibatis.plugin.Signature; import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;  import dwz.common.mybatis.Page; import dwz.common.util.ReflectUtil;  @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) }) @SuppressWarnings("rawtypes") public class PageInterceptor implements Interceptor {  	private static String databaseType ="";// 数据库类型,不同的数据库有不同的分页方法 	/** 	 * 拦截后要执行的方法 	 */ 	public Object intercept(Invocation invocation) throws Throwable {  		RoutingStatementHandler handler = (RoutingStatementHandler) invocation 				.getTarget(); 		StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate"); 		BoundSql boundSql = delegate.getBoundSql(); 		Object params = boundSql.getParameterObject(); 		Page page = null; 		if (params instanceof Page) { 			page = (Page) params; 		} else if (params instanceof MapperMethod.ParamMap) { 			MapperMethod.ParamMap paramMap = (MapperMethod.ParamMap) params; 			for (Object key : paramMap.keySet()) { 				if (paramMap.get(key) instanceof Page) { 					page = (Page) paramMap.get(key); 					break; 		
		} 			} 		}  		if (page != null) { 			MappedStatement mappedStatement = (MappedStatement) ReflectUtil 					.getFieldValue(delegate, "mappedStatement"); 			Connection connection = (Connection) invocation.getArgs()[0]; 			String sql = boundSql.getSql(); 			this.setTotalRecord(page, (MapperMethod.ParamMap) params, 					mappedStatement, connection); 			String pageSql = this.getPageSql(page, sql); 			ReflectUtil.setFieldValue(boundSql, "sql", pageSql); 		} 		return invocation.proceed(); 	}  	/** 	 * 拦截器对应的封装原始对象的方法 	 */ 	public Object plugin(Object target) { 		return Plugin.wrap(target, this); 	}  	public void setProperties(Properties p) { 		databaseType = p.getProperty("databaseType"); 		if (StringUtils.isEmpty(databaseType)) { 			try { 				throw new PropertyException("databaseType is not found!"); 			} catch (PropertyException e) { 				e.printStackTrace(); 			} 		}  	} 	private String getPageSql(Page<?> page, String sql) { 		StringBuffer sqlBuffer = new StringBuffer(sql); 		if ("mysql".equalsIgnoreCase(databaseType)) { 			return getMysqlPageSql(page, sqlBuffer); 		} else if ("oracle".equalsIgnoreCase(databaseType)) { 			return getOraclePageSql(page, sqlBuffer); 		} else if ("sqlserver".equalsIgnoreCase(databaseType)) { 			return getSqlserverPageSql(page, sqlBuffer); 		} 		return sqlBuffer.toString(); 	}  	private String getSqlserverPageSql(Page<?> page, StringBuffer sqlBuffer) { 		// 计算第一条记录的位置,Sqlserver中记录的位置是从0开始的。 		int startRowNum = (page.getPageNum() - 1) * page.getPageSize() + 1; 		int endRowNum = startRowNum + page.getPageSize(); 		String sql = "select appendRowNum.row,* from (select ROW_NUMBER() OVER (order by (select 0)) AS row,* from (" 				+ sqlBuffer.toString() 				+ ") as innerTable" 				+ ")as appendRowNum where appendRowNum.row >= " 				+ startRowNum 				+ " AND appendRowNum.row <= " + endRowNum; 		return sql; 	}  	private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) { 		// 计算第一条记录的位置,Mysql中记录的位置是从0开始的。 		int offset = (page.getPageNum() 
- 1) * page.getPageSize(); 		sqlBuffer.append(" limit ").append(offset).append(",").append(page.getPageSize()); 		return sqlBuffer.toString(); 	}  	private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) { 		// 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的 		int offset = (page.getPageNum() - 1) * page.getPageSize() + 1; 		sqlBuffer.insert(0, "select u.*, rownum r from (").append(") u where rownum < ") 			.append(offset + page.getPageSize()); 		sqlBuffer.insert(0, "select * from (").append(") where r >= ").append(offset); 		return sqlBuffer.toString(); 	}  	/** 	 * 给当前的参数对象page设置总记录数 	 *  	 * @param page 	 *            Mapper映射语句对应的参数对象 	 * @param mappedStatement 	 *            Mapper映射语句 	 * @param connection 	 */ 	private void setTotalRecord(Page<?> page, MapperMethod.ParamMap params, 			MappedStatement mappedStatement, Connection connection) { 		BoundSql boundSql = mappedStatement.getBoundSql(params); 		String sql = boundSql.getSql(); 		String countSql = this.getCountSql(sql); 		List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); 		BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql,parameterMappings, params); 		ParameterHandler parameterHandler = new DefaultParameterHandler( 				mappedStatement, params, countBoundSql); 		PreparedStatement pstmt = null; 		ResultSet rs = null; 		try { 			pstmt = connection.prepareStatement(countSql); 			parameterHandler.setParameters(pstmt); 			rs = pstmt.executeQuery(); 			if (rs.next()) { 				int totalRecord = rs.getInt(1); 				page.setTotalRecord(totalRecord); 			} 		} catch (SQLException e) { 			e.printStackTrace(); 		} finally { 			try { 				if (rs != null)rs.close(); 				if (pstmt != null)pstmt.close(); 			} catch (SQLException e) { 				e.printStackTrace(); 			} 		} 	}  	/** 	 * 根据原Sql语句获取对应的查询总记录数的Sql语句 	 *  	 * @param sql 	 * @return 	 */ 	private String getCountSql(String sql) { 		return "select count(*) from (" + sql + ") as countRecord"; 	}  } 

插件所需的反射工具类 ReflectUtil 以及分页参数类 Page 如下:

 package dwz.common.util;  import java.lang.reflect.Field; import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method;  import org.apache.commons.lang3.StringUtils; import org.apache.commons.beanutils.BeanUtils; import org.springframework.util.Assert;   /**  * 利用反射进行操作的一个工具类  */ public class ReflectUtil { 	/** 	 * 利用反射获取指定对象的指定属性 	 *  	 * @param obj 	 *            目标对象 	 * @param fieldName 	 *            目标属性 	 * @return 目标属性的值 	 */ 	public static Object getFieldValue(Object obj, String fieldName) { 		Object result = null; 		Field field = ReflectUtil.getField(obj, fieldName); 		if (field != null) { 			field.setAccessible(true); 			try { 				result = field.get(obj); 			} catch (IllegalArgumentException e) { 				e.printStackTrace(); 			} catch (IllegalAccessException e) { 				e.printStackTrace(); 			} 		} 		return result; 	}  	/** 	 * 利用反射获取指定对象里面的指定属性 	 *  	 * @param obj 	 *            目标对象 	 * @param fieldName 	 *            目标属性 	 * @return 目标字段 	 */ 	private static Field getField(Object obj, String fieldName) { 		Field field = null; 		for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz 				.getSuperclass()) { 			try { 				field = clazz.getDeclaredField(fieldName); 				break; 			} catch (NoSuchFieldException e) { 				// 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。 			} 		} 		return field; 	}  	/** 	 * 利用反射设置指定对象的指定属性为指定的值 	 *  	 * @param obj 	 *            目标对象 	 * @param fieldName 	 *            目标属性 	 * @param fieldValue 	 *            目标值 	 */ 	public static void setFieldValue(Object obj, String fieldName, 			String fieldValue) { 		Field field = ReflectUtil.getField(obj, fieldName); 		if (field != null) { 			try { 				field.setAccessible(true); 				field.set(obj, fieldValue); 			} catch (IllegalArgumentException e) { 				e.printStackTrace(); 			} catch (IllegalAccessException e) { 				e.printStackTrace(); 			} 		} 	}  	/** 	 * 两者属性名一致时,拷贝source里的属性到dest里 	 *  	 * @return void 	 * @throws IllegalArgumentException  	 * 
@throws IllegalAccessException 	 * @throws InvocationTargetException 	 */ 	@SuppressWarnings("unchecked") 	public static void copyPorperties(Object dest, Object source) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{ 		Class srcCla = source.getClass(); 		Field[] fsF = srcCla.getDeclaredFields();  		for (Field s : fsF) 		{ 			String name = s.getName(); 			Object srcObj = invokeGetterMethod(source, name); 			try 			{ 				BeanUtils.setProperty(dest, name, srcObj); 			} 			catch (Exception e){ 				e.printStackTrace(); 			}  		} 	}  	/** 	 * 调用Getter方法. 	 * @throws InvocationTargetException  	 * @throws IllegalArgumentException  	 * @throws IllegalAccessException  	 */ 	public static Object invokeGetterMethod(Object target, String propertyName) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException 	{ 		String getterMethodName = "get" + StringUtils.capitalize(propertyName); 		return invokeMethod(target, getterMethodName, new Class[] {}, 				new Object[] {}); 	}  	/** 	 * 直接调用对象方法, 无视private/protected修饰符. 	 * @throws InvocationTargetException  	 * @throws IllegalArgumentException  	 * @throws IllegalAccessException  	 */ 	public static Object invokeMethod(final Object object, 			final String methodName, final Class<?>[] parameterTypes, 			final Object[] parameters) throws IllegalAccessException, IllegalArgumentException, InvocationTargetException{ 		Method method = getDeclaredMethod(object, methodName, parameterTypes); 		if (method == null) 		{ 			throw new IllegalArgumentException("Could not find method [" 					+ methodName + "] parameterType " + parameterTypes 					+ " on target [" + object + "]"); 		}  		method.setAccessible(true); 		return method.invoke(object, parameters); 	}  	/** 	 * 循环向上转型, 获取对象的DeclaredMethod. 	 *  	 * 如向上转型到Object仍无法找到, 返回null. 	 
*/ 	protected static Method getDeclaredMethod(Object object, String methodName, 			Class<?>[] parameterTypes) 	{ 		Assert.notNull(object, "object不能为空");  		for (Class<?> superClass = object.getClass(); superClass != Object.class; superClass = superClass 				.getSuperclass()) 		{ 			try{ 				return superClass.getDeclaredMethod(methodName, parameterTypes); 			} 			catch (NoSuchMethodException e) 			{// NOSONAR 				// Method不在当前类定义,继续向上转型 			} 		} 		return null; 	} }  

Page 类(可按需重写 toString 方法):

 package dwz.common.mybatis;

import java.util.List;

/**
 * Holds pagination input (page number and size) and output (record and page totals, plus
 * the current page's records).
 *
 * @param <T> element type of the current page's records
 */
public class Page<T> {
    /** 1-based page number; defaults to the first page. */
    private int pageNum = 1;
    /** Records per page; defaults to 5. */
    private int pageSize = 5;
    /** Total number of matching records. */
    private int totalRecord;
    /** Total number of records; setTotalRecord keeps it equal to totalRecord. */
    private int total;
    /** Total number of pages, derived in setTotalRecord. */
    private int totalPage;
    /** Records of the current page. */
    private List<T> results;

    public int getTotal() {
        return total;
    }

    /** Sets only {@code total}; unlike setTotalRecord it does not touch totalRecord or totalPage. */
    public void setTotal(int total) {
        this.total = total;
    }

    public int getPageNum() {
        return pageNum;
    }

    public void setPageNum(int pageNum) {
        this.pageNum = pageNum;
    }

    public int getPageSize() {
        return pageSize;
    }

    public void setPageSize(int pageSize) {
        this.pageSize = pageSize;
    }

    public int getTotalRecord() {
        return totalRecord;
    }

    /**
     * Sets the total record count, mirrors it into {@code total}, and recomputes
     * {@code totalPage} as ceil(totalRecord / pageSize).
     * When pageSize <= 0 the page count is 0; previously a zero pageSize threw
     * ArithmeticException here.
     */
    public void setTotalRecord(int totalRecord) {
        this.totalRecord = totalRecord;
        this.total = totalRecord;
        int pages = 0;
        if (pageSize > 0) {
            pages = totalRecord / pageSize + (totalRecord % pageSize == 0 ? 0 : 1);
        }
        this.setTotalPage(pages);
    }

    public int getTotalPage() {
        return totalPage;
    }

    public void setTotalPage(int totalPage) {
        this.totalPage = totalPage;
    }

    /**
     * Returns the current page's records. An empty list is reported as null, so callers see
     * null for both "not set" and "no records".
     */
    public List<T> getResults() {
        if (results != null && results.isEmpty()) {
            return null;
        }
        return results;
    }

    public void setResults(List<T> results) {
        this.results = results;
    }
}

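最后将该插件注册到 MyBatis 中,并为 databaseType 赋值(取值为 mysql、oracle 或 sqlserver,不区分大小写)。注意:下面的示例是 Spring 配置文件中的 bean 定义,而不是 mybatis-config.xml 的写法;定义该 bean 后,还需要在 SqlSessionFactoryBean 的 plugins 属性中引用它。如果使用 mybatis-config.xml,则应在 plugins 节点下通过 plugin interceptor="dwz.common.mybatis.PageInterceptor" 配置,并用 property name="databaseType" value="mysql" 赋值。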

 <!-- Pagination plugin bean.
      The "properties" property is handed to PageInterceptor.setProperties(Properties).
      "databaseType" selects the dialect-specific paging SQL: mysql, oracle or sqlserver,
      compared case-insensitively.
      NOTE(review): declaring this bean alone presumably does not register the interceptor
      with MyBatis; it likely also has to be referenced from the SqlSessionFactoryBean
      "plugins" property. Confirm in the full Spring configuration. -->
<bean id="pagePlugin" class="dwz.common.mybatis.PageInterceptor">
	<property name="properties">
		<props>
			<prop key="databaseType">mysql</prop>
		</props>
	</property>
</bean>

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!