mybatis拦截器实现数据库数据权限隔离方式

 更新时间:2024年11月06日 09:39:37   作者:不要停下脚步  
通过Mybatis拦截器,在执行SQL前添加条件实现数据权限隔离,特别是对于存在用户ID区分的表,拦截器会自动添加如user_id=#{userId}的条件,确保SQL在执行时只能操作指定用户的数据,此方法主要应用于Mybatis的四个阶段

原理

使用拦截器在mybatis 执行sql 之前 ,

将sql 后面加上指定的查询条件 

比如,你的表以user_id 作为区分 

那么你就需要在sql 拦截器中加上 user_id = #{userId} 的逻辑

实现

mybatis 拦截器的相关知识不再赘述 , 可以在mybatis 的四个阶段进行拦截

分别是 Execute , MappedStatment , ParamHanlder ,以及 ResultHandler

详细的每个阶段做什么事情 ,可以自行百度。

 @AuthFilter(userFiled = "user_id" , ignoreOrgFiled = true)
    Page getUserMsgPage(@Param("page")Page page , @Param("param") MsgUserRefDto param , @Param("loginId") String loginId , @Param("orderBy")String orderBy);

具体效果就是 , 我们希望上面的sql 在执行的时候 ,自动拼接上 and user_id = 1 ,去过滤指定用户的数据。

配置文件

@Configuration
@AutoConfigureAfter(PageHelperAutoConfiguration.class)
public class MybatisConfig {

    @Autowired
    private List<SqlSessionFactory> sqlSessionFactoryList;

    @PostConstruct
    void mybatisConfigurationCustomizer() {

        AuthInterceptor authInterceptor = new AuthInterceptor();
        sqlSessionFactoryList.forEach(o->{
            o.getConfiguration().addInterceptor(authInterceptor);
        });
    }
}

自定义注解

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD , ElementType.TYPE})
@Documented
public @interface AuthFilter {

    String userFiled() default "userId";

    String orgFiled() default "orgId";

    boolean ignoreUserFiled() default false;

    boolean ignoreOrgFiled() default false;
}

具体拦截器逻辑

其中,GlobalHolder 就是每个系统中自己存储用户登录信息的容器 。

@Slf4j
@Component
@Intercepts({@Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
), @Signature(
        type = Executor.class,
        method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
)})
public class AuthInterceptor implements Interceptor {

   private static final Map<Class<?>, Map<String, List<List<Class>>>> mapperCache = new ConcurrentHashMap();


    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        Object[] args = invocation.getArgs();
        String id = ((MappedStatement)args[0]).getId();
        String clazzName = id.substring(0, id.lastIndexOf('.'));
        String mapperMethod = id.substring(id.lastIndexOf('.') + 1);

        Object[] paramArr = getParamArr(args[1]);
        Class<?> clazz = Class.forName(clazzName);

        Method method = getMethod(clazz, mapperMethod, paramArr);
        AuthFilter authFilter = method.getAnnotation(AuthFilter.class);


        // 如果方法没有加上注解正常执行 ,否则开始解析
        if (authFilter != null) {

            Map params = new HashMap();
            // 获取各个filed
            String orgFiled = authFilter.orgFiled();
            String userFiled = authFilter.userFiled();
            // 获取用户登录id 和 组织Id
            String orgId = GlobalHolder.getOrgId();
            String loginId = GlobalHolder.getLoginId();

            boolean ignoreOrgFiled = authFilter.ignoreOrgFiled();
            boolean ignoreUserFiled = authFilter.ignoreUserFiled();

            MappedStatement ms = (MappedStatement)args[0];
            Object parameter = args[1];
            BoundSql boundSql;
            if (args.length == 4) {
                boundSql = ms.getBoundSql(parameter);
            } else {
                boundSql = (BoundSql)args[5];
            }

            String sql = boundSql.getSql();

            // 添加组织编号
            if (!ignoreOrgFiled) {

                if(StringUtils.isNotEmpty(orgId)){
                    params.put(orgFiled , orgId);
                }else {
                    throw new IllegalStateException("用户未登录!");
                }

            }

            if (!ignoreUserFiled) {

                if(StringUtils.isNotEmpty(loginId)){
                    params.put(userFiled , loginId);
                }else {
                    throw new IllegalStateException("用户未登录!");
                }
            }

            if(params.size() > 0){
               String concatSql = contactConditions(wrapSql(sql) , params);
                log.info("添加后的sql为: {}" , concatSql);
                ReflectUtil.setFieldValue(boundSql, "sql", concatSql);
            }
        }
        return invocation.proceed();
    }


    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);

    }

    @Override
    public void setProperties(Properties properties) {
    }

    private String wrapSql(String sql){

        if(StringUtils.isNotEmpty(sql)){

            StringBuilder realSql = new StringBuilder();
            realSql.append("select * from ( ");
            realSql.append(sql);
            realSql.append(") a");

            return realSql.toString();
        }
        return sql;
    }

    /** 获取 mapper 相应 Method 反射类 */
    private Method getMethod(Class<?> clazz, String mapperMethod, Object[] paramArr) throws NoSuchMethodException, NoSuchFieldException, IllegalAccessException {
        // 1、查 mapper 接口缓存
        if (!mapperCache.containsKey(clazz)) // mapper 没有缓存, 就进行缓存
        {
            cacheMapper(clazz);
        }
        // 2、返回相应 method
        A:
        for (List<Class> paramList : mapperCache.get(clazz).get(mapperMethod)) {
            if (!paramList.isEmpty()) {
                for (int i = 0; i < paramArr.length; i++) { // 比较参数列表class
                    if (paramArr[i] != null)
                        if (!compareClass(paramList.get(i), paramArr[i].getClass())) continue A;
                }
                return clazz.getMethod(mapperMethod, paramList.toArray(new Class[paramList.size()]));
            }
        }
        return clazz.getMethod(mapperMethod); // 返回无参方法
    }

        /** 对 mapper 方法字段进行缓存 */
        private void cacheMapper(Class<?> clazz) {
            Map<String, List<List<Class>>> methodMap = new HashMap();
            for(Method method : clazz.getMethods()) {
                List<List<Class>> paramLists = methodMap.containsKey(method.getName()) ?
                        methodMap.get(method.getName()) : new ArrayList<List<Class>>();
                List<Class> paramClass = new ArrayList<Class>();
                for (Type type : method.getParameterTypes())
                {
                    paramClass.add((Class) type);
                }
                paramLists.add(paramClass);
                methodMap.put(method.getName(), paramLists);
            }
            mapperCache.put(clazz, methodMap);
        }

        /** class 比较 */
        private boolean compareClass(Class<?> returnType, Class<?> paramType) throws NoSuchFieldException, IllegalAccessException {
            if(returnType == paramType) {
                return true;
            }
            else if(returnType.isAssignableFrom(paramType)) { // 判断 paramType 是否为 returnType 子类或者实现类
                return true;
            }
            // 基本数据类型判断
            else if(returnType.isPrimitive()) { // paramType为包装类
                return returnType == paramType.getField("TYPE").get(null);
            }
            else if(paramType.isPrimitive()) { // returnType为包装类
                return paramType == returnType.getField("TYPE").get(null);
            }
            return false;
        }

    /**
     * 获取 mybatis 中 mapper 接口的参数列表的参数值
     * @param parameter
     * @return
     */
    private Object[] getParamArr(Object parameter) {
        Object[] paramArr = null;
        // mapper 接口中使用的是 paramMap, 传多个参数
        if(parameter instanceof MapperMethod.ParamMap)
        {
            Map map = ((Map) parameter);
            if(!map.isEmpty()) {
                StringBuilder builder = new StringBuilder();
                // 初始化 param_arr
                int size = map.size() >> 1;
                paramArr = new Object[size];
                for(int i = 1;i <= size;i ++)
                {
                    // mapper 接口中使用 param0 ~ paramN 命名参数
                    paramArr[i - 1] = map.get(builder.append("param").append(i).toString());
                    builder.setLength(0);
                }
            }
        }
        else if(parameter != null)
        {
            paramArr = new Object[1];
            paramArr[0] = parameter;
        }
        return paramArr;
    }


    private static String contactConditions(String sql, Map<String, Object> columnMap) {
        SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, JdbcUtils.MYSQL);
        List<SQLStatement> stmtList = parser.parseStatementList();
        SQLStatement stmt = stmtList.get(0);
        if (stmt instanceof SQLSelectStatement) {
            StringBuffer constraintsBuffer = new StringBuffer();
            Set<String> keys = columnMap.keySet();
            Iterator<String> keyIter = keys.iterator();
            if (keyIter.hasNext()) {
                String key = keyIter.next();
                constraintsBuffer.append(key).append(" = " + getSqlByClass(columnMap.get(key)));
            }
            while (keyIter.hasNext()) {
                String key = keyIter.next();
                constraintsBuffer.append(" AND ").append(key).append(" = " + getSqlByClass(columnMap.get(key)));
            }
            SQLExprParser constraintsParser = SQLParserUtils.createExprParser(constraintsBuffer.toString(), JdbcUtils.MYSQL);
            SQLExpr constraintsExpr = constraintsParser.expr();

            SQLSelectStatement selectStmt = (SQLSelectStatement) stmt;
            // 拿到SQLSelect
            SQLSelect sqlselect = selectStmt.getSelect();
            SQLSelectQueryBlock query = (SQLSelectQueryBlock) sqlselect.getQuery();
            SQLExpr whereExpr = query.getWhere();
            // 修改where表达式
            if (whereExpr == null) {
                query.setWhere(constraintsExpr);
            } else {
                SQLBinaryOpExpr newWhereExpr = new SQLBinaryOpExpr(whereExpr, SQLBinaryOperator.BooleanAnd, constraintsExpr);
                query.setWhere(newWhereExpr);
            }
            sqlselect.setQuery(query);
            return sqlselect.toString();

        }

        return sql;
    }

    private static String getSqlByClass(Object value){

        if(value instanceof Number){
            return value + "";
        }else if(value instanceof String){
            return "'" + value + "'";
        }

        return "'" + value.toString() + "'";
    }

}

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

相关文章

  • Mybatis select记录封装的实现

    Mybatis select记录封装的实现

    这篇文章主要介绍了Mybatis select记录封装的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2020-10-10
  • SpringBoot整合mybatis使用Druid做连接池的方式

    SpringBoot整合mybatis使用Druid做连接池的方式

    这篇文章主要介绍了SpringBoot整合mybatis使用Druid做连接池的方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教
    2023-08-08
  • Servlet实现代理文件下载功能

    Servlet实现代理文件下载功能

    这篇文章主要为大家详细介绍了Servlet实现代理文件下载功能,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2017-12-12
  • Java中的OkHttp使用教程

    Java中的OkHttp使用教程

    OkHttp是目前非常火的网络库,OKHttp与HttpClient类似,也是一个Http客户端,提供了对 HTTP/2 和 SPDY 的支持,并提供了连接池,GZIP 压缩和 HTTP 响应缓存功能,本文重点给大家介绍Java OkHttp使用,感兴趣的朋友一起看看吧
    2022-04-04
  • java爱心代码完整示例(脱单必备)

    java爱心代码完整示例(脱单必备)

    最近看到个好玩的,就是用代码实现爱心的形状,对于不懂编程的人来说,这是一个很好的玩的东西,这篇文章主要给大家介绍了关于java爱心代码的相关资料,需要的朋友可以参考下
    2023-07-07
  • Java无限级树(递归)超实用案例

    Java无限级树(递归)超实用案例

    下面小编就为大家带来一篇Java无限级树(递归)超实用案例。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧
    2016-11-11
  • Java由浅入深细数数组的操作下

    Java由浅入深细数数组的操作下

    数组对于每一门编程语言来说都是重要的数据结构之一,当然不同语言对数组的实现及处理也不尽相同。Java 语言中提供的数组是用来存储固定大小的同类型元素
    2022-04-04
  • Java中FilterInputStream和FilterOutputStream的用法详解

    Java中FilterInputStream和FilterOutputStream的用法详解

    这篇文章主要介绍了Java中FilterInputStream和FilterOutputStream的用法详解,这两个类分别用于封装输入和输出流,需要的朋友可以参考下
    2016-06-06
  • 查看Spring容器中bean的五种方法小结

    查看Spring容器中bean的五种方法小结

    近期在写Spring项目的时候,需要通过注解的形式去替代之前直接将Bean存放在Spring容器这种方式,以此来简化对于Bean对象的操作,这篇文章主要给大家介绍了关于如何查看Spring容器中bean的五种方法,需要的朋友可以参考下
    2024-05-05
  • java接口幂等性的实现方式

    java接口幂等性的实现方式

    本文介绍了在不同层面上实现Java接口幂等性的方法,包括使用幂等表、Nginx+Lua和Redis、以及SpringAOP,通过这些方法,可以确保接口在多次请求时只执行一次,避免重复处理和数据不一致,每种方法都有其适用场景和优势,通过实际测试验证了幂等性逻辑的有效性
    2025-01-01

最新评论