修复postgresql下扩展条件类型不匹配的问题

This commit is contained in:
mazhicheng 2020-03-21 15:28:52 +08:00
parent 7917f3bc4f
commit 9a884528b1
11 changed files with 131 additions and 76 deletions

View File

@ -1,7 +1,7 @@
package com.diboot.core.starter;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.diboot.core.util.ContextHelper;
import com.diboot.core.util.S;
import com.diboot.core.util.SqlExecutor;
import com.diboot.core.util.V;
@ -28,7 +28,6 @@ public class SqlHandler {
// 数据字典SQL
private static final String DICTIONARY_SQL = "SELECT id FROM ${SCHEMA}.dictionary WHERE id=0";
private static final String MYBATIS_PLUS_SCHEMA_CONFIG = "mybatis-plus.global-config.db-config.schema";
private static String dbType;
private static String CURRENT_SCHEMA = null;
private static Environment environment;
@ -38,8 +37,6 @@ public class SqlHandler {
*/
public static void init(Environment env) {
environment = env;
String jdbcUrl = getJdbcUrl(environment);
dbType = extractDatabaseType(jdbcUrl);
}
/***
@ -47,9 +44,8 @@ public class SqlHandler {
* @return
*/
public static void initBootstrapSql(Class inst, Environment environment, String module){
if(dbType == null){
init(environment);
}
init(environment);
String dbType = getDbType();
String sqlPath = "META-INF/sql/init-"+module+"-"+dbType+".sql";
extractAndExecuteSqls(inst, sqlPath);
}
@ -125,10 +121,10 @@ public class SqlHandler {
sqlStatement = clearComments(sqlStatement);
// 替换sqlStatement中的变量{SCHEMA}
if(sqlStatement.contains("${SCHEMA}")){
if(dbType.equals(DbType.SQL_SERVER.getDb())){
if(getDbType().equals(DbType.SQL_SERVER.getDb())){
sqlStatement = S.replace(sqlStatement, "${SCHEMA}", getSqlServerCurrentSchema());
}
else if(dbType.equals(DbType.ORACLE.getDb())){
else if(getDbType().equals(DbType.ORACLE.getDb())){
sqlStatement = S.replace(sqlStatement, "${SCHEMA}", getOracleCurrentSchema());
}
else{
@ -181,23 +177,6 @@ public class SqlHandler {
return lines;
}
/**
* 获取JDBC url
* @param environment
* @return
*/
public static String getJdbcUrl(Environment environment){
String jdbcUrl = environment.getProperty("spring.datasource.url");
if(jdbcUrl == null){
String master = environment.getProperty("spring.datasource.dynamic.primary");
jdbcUrl = environment.getProperty("spring.datasource.dynamic.datasource."+master+".url");
if(jdbcUrl == null){
log.warn("无法获取 datasource url 配置,请检查!");
}
}
return jdbcUrl;
}
/***
* 剔除SQL中的注释提取可执行的实际SQL
* @param inputSql
@ -242,20 +221,6 @@ public class SqlHandler {
return inputSql;
}
/**
* 提取数据库类型
* @param jdbcUrl
* @return
*/
public static String extractDatabaseType(String jdbcUrl){
DbType dbType = JdbcUtils.getDbType(jdbcUrl);
String dbName = dbType.getDb();
if(dbName.startsWith(DbType.SQL_SERVER.getDb()) && !dbName.equals(DbType.SQL_SERVER.getDb())){
dbName = DbType.SQL_SERVER.getDb();
}
return dbName;
}
//SQL Server查询当前schema
public static final String SQL_DEFAULT_SCHEMA = "SELECT DISTINCT default_schema_name FROM sys.database_principals where default_schema_name is not null AND name!='guest'";
/**
@ -330,6 +295,6 @@ public class SqlHandler {
* @return
*/
public static String getDbType(){
return dbType;
return ContextHelper.getDatabaseType();
}
}

View File

@ -10,9 +10,9 @@ import com.diboot.core.binding.binder.EntityBinder;
import com.diboot.core.binding.binder.EntityListBinder;
import com.diboot.core.binding.binder.FieldBinder;
import com.diboot.core.binding.parser.BindAnnotationGroup;
import com.diboot.core.binding.parser.ParserCache;
import com.diboot.core.binding.parser.ConditionManager;
import com.diboot.core.binding.parser.FieldAnnotation;
import com.diboot.core.binding.parser.ParserCache;
import com.diboot.core.entity.Dictionary;
import com.diboot.core.service.DictionaryService;
import com.diboot.core.util.BeanUtils;

View File

@ -5,17 +5,15 @@ import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.extension.service.IService;
import com.diboot.core.binding.parser.MiddleTable;
import com.diboot.core.config.BaseConfig;
import com.diboot.core.exception.BusinessException;
import com.diboot.core.service.BaseService;
import com.diboot.core.util.BeanUtils;
import com.diboot.core.util.IGetter;
import com.diboot.core.util.S;
import com.diboot.core.util.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.lang.reflect.Field;
import java.math.BigDecimal;
import java.util.*;
/**
* 关系绑定Binder父类
@ -52,6 +50,8 @@ public abstract class BaseBinder<T> {
*/
protected MiddleTable middleTable;
protected Class<T> referencedEntityClass;
/**
* join连接条件指定当前VO的取值方法和关联entity的取值方法
* @param annoObjectFkGetter 当前VO的取值方法
@ -77,27 +77,27 @@ public abstract class BaseBinder<T> {
}
public BaseBinder<T> andEQ(String fieldName, Object value){
queryWrapper.eq(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.eq(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andNE(String fieldName, Object value){
queryWrapper.ne(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.ne(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andGT(String fieldName, Object value){
queryWrapper.gt(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.gt(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andGE(String fieldName, Object value){
queryWrapper.ge(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.ge(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andLT(String fieldName, Object value){
queryWrapper.lt(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.lt(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andLE(String fieldName, Object value){
queryWrapper.le(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.le(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andIsNotNull(String fieldName){
@ -109,11 +109,11 @@ public abstract class BaseBinder<T> {
return this;
}
public BaseBinder<T> andBetween(String fieldName, Object begin, Object end){
queryWrapper.between(S.toSnakeCase(fieldName), formatValue(begin), formatValue(end));
queryWrapper.between(S.toSnakeCase(fieldName), formatValue(fieldName, begin), formatValue(fieldName, end));
return this;
}
public BaseBinder<T> andLike(String fieldName, String value){
queryWrapper.like(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.like(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andIn(String fieldName, Collection valueList){
@ -125,11 +125,11 @@ public abstract class BaseBinder<T> {
return this;
}
public BaseBinder<T> andNotBetween(String fieldName, Object begin, Object end){
queryWrapper.notBetween(S.toSnakeCase(fieldName), formatValue(begin), formatValue(end));
queryWrapper.notBetween(S.toSnakeCase(fieldName), formatValue(fieldName, begin), formatValue(fieldName, end));
return this;
}
public BaseBinder<T> andNotLike(String fieldName, String value){
queryWrapper.notLike(S.toSnakeCase(fieldName), formatValue(value));
queryWrapper.notLike(S.toSnakeCase(fieldName), formatValue(fieldName, value));
return this;
}
public BaseBinder<T> andApply(String applySql){
@ -209,13 +209,48 @@ public abstract class BaseBinder<T> {
/**
* 格式化条件值
* @param value
* @param fieldName 属性名
* @param value
* @return
*/
private Object formatValue(Object value){
private Object formatValue(String fieldName, Object value){
if(value instanceof String && S.contains((String)value, "'")){
value = S.replace((String)value, "'", "");
return S.replace((String)value, "'", "");
}
// 转型
if(this.referencedEntityClass != null){
Field field = BeanUtils.extractField(this.referencedEntityClass, S.toLowerCaseCamel(fieldName));
if(field != null){
String valueStr = S.valueOf(value);
String type = field.getGenericType().getTypeName();
if(Integer.class.getName().equals(type)){
return Integer.parseInt(valueStr);
}
else if(Long.class.getName().equals(type)){
return Long.parseLong(valueStr);
}
else if(Double.class.getName().equals(type)){
return Double.parseDouble(valueStr);
}
else if(BigDecimal.class.getName().equals(type)){
return new BigDecimal(valueStr);
}
else if(Float.class.getName().equals(type)){
return Float.parseFloat(valueStr);
}
else if(Boolean.class.getName().equals(type)){
return V.isTrue(valueStr);
}
else if(type.contains(Date.class.getSimpleName())){
return D.fuzzyConvert(valueStr);
}
}
}
else{
throw new BusinessException("dddd");
}
return value;
}
}

View File

@ -42,6 +42,7 @@ public class EntityBinder<T> extends BaseBinder<T> {
this.referencedService = referencedService;
this.annoObjectList = voList;
this.queryWrapper = new QueryWrapper<T>();
this.referencedEntityClass = BeanUtils.getGenericityClass(referencedService, 1);
}
/***

View File

@ -31,6 +31,7 @@ public class EntityListBinder<T> extends EntityBinder<T> {
this.referencedService = serviceInstance;
this.annoObjectList = voList;
this.queryWrapper = new QueryWrapper<T>();
this.referencedEntityClass = BeanUtils.getGenericityClass(referencedService, 1);
}
@Override

View File

@ -35,6 +35,7 @@ public class FieldBinder<T> extends BaseBinder<T> {
this.referencedService = serviceInstance;
this.annoObjectList = voList;
this.queryWrapper = new QueryWrapper<T>();
this.referencedEntityClass = BeanUtils.getGenericityClass(referencedService, 1);
}
/***

View File

@ -1,5 +1,6 @@
package com.diboot.core.service.impl;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.conditions.Wrapper;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
@ -86,12 +87,20 @@ public class BaseServiceImpl<M extends BaseCrudMapper<T>, T> extends ServiceImpl
@Override
@Transactional(rollbackFor = Exception.class)
public boolean createEntities(Collection entityList){
public boolean createEntities(Collection<T> entityList){
if(V.isEmpty(entityList)){
return false;
}
// 批量插入
return super.saveBatch(entityList, BaseConfig.getBatchSize());
if(DbType.SQL_SERVER.getDb().equalsIgnoreCase(ContextHelper.getDatabaseType())){
for(T entity : entityList){
createEntity(entity);
}
return true;
}
else{
// 批量插入
return super.saveBatch(entityList, BaseConfig.getBatchSize());
}
}
@Override

View File

@ -87,23 +87,19 @@ public class DictionaryServiceImpl extends BaseServiceImpl<DictionaryMapper, Dic
@Override
@Transactional(rollbackFor = Exception.class)
public boolean createDictAndChildren(DictionaryVO dictVO) {
Dictionary dictionary = (Dictionary)dictVO;
Dictionary dictionary = dictVO;
if(!super.createEntity(dictionary)){
log.warn("新建数据字典定义失败type="+dictVO.getType());
return false;
}
List<Dictionary> children = dictVO.getChildren();
if(V.notEmpty(children)){
boolean success = true;
for(Dictionary dict : children){
dict.setParentId(dictionary.getId());
dict.setType(dictionary.getType());
boolean insertOK = super.createEntity(dict);
if(!insertOK){
log.warn("dictionary插入数据字典失败请检查");
success = false;
}
}
// 批量保存
boolean success = super.saveBatch(children);
if(!success){
String errorMsg = "新建数据字典子项失败type="+dictVO.getType();
log.warn(errorMsg);
@ -117,7 +113,7 @@ public class DictionaryServiceImpl extends BaseServiceImpl<DictionaryMapper, Dic
@Transactional(rollbackFor = Exception.class)
public boolean updateDictAndChildren(DictionaryVO dictVO) {
//将DictionaryVO转化为Dictionary
Dictionary dictionary = (Dictionary)dictVO;
Dictionary dictionary = dictVO;
if(!super.updateEntity(dictionary)){
log.warn("更新数据字典定义失败type="+dictVO.getType());
return false;
@ -151,7 +147,7 @@ public class DictionaryServiceImpl extends BaseServiceImpl<DictionaryMapper, Dic
if(!dictItemIds.contains(dict.getId())){
if(!super.deleteEntity(dict.getId())){
log.warn("删除子数据字典失败itemName="+dict.getItemName());
throw new RuntimeException();
throw new BusinessException(Status.FAIL_EXCEPTION, "删除字典子项异常");
}
}
}

View File

@ -1,16 +1,22 @@
package com.diboot.core.util;
import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.baomidou.mybatisplus.extension.service.IService;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.diboot.core.config.Cons;
import com.diboot.core.entity.BaseEntity;
import com.diboot.core.service.BaseService;
import org.apache.ibatis.session.SqlSessionFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Lazy;
import org.springframework.core.env.Environment;
import org.springframework.stereotype.Component;
import org.springframework.web.context.ContextLoader;
@ -49,6 +55,10 @@ public class ContextHelper implements ApplicationContextAware {
* 存储主键字段非id的Entity
*/
private static Map<String, String> PK_NID_ENTITY_CACHE = new ConcurrentHashMap<>();
/**
* 数据库类型
*/
private static String DATABASE_TYPE = null;
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
@ -132,6 +142,7 @@ public class ContextHelper implements ApplicationContextAware {
* @param entity
* @return
*/
@Deprecated
public static IService getIServiceByEntity(Class entity){
if(ENTITY_SERVICE_CACHE == null){
ENTITY_SERVICE_CACHE = new ConcurrentHashMap<>();
@ -203,4 +214,35 @@ public class ContextHelper implements ApplicationContextAware {
}
return PK_NID_ENTITY_CACHE.get(entity.getName());
}
/**
* 获取数据库类型
* @return
*/
public static String getDatabaseType(){
if(DATABASE_TYPE != null){
return DATABASE_TYPE;
}
Environment environment = getApplicationContext().getEnvironment();
String jdbcUrl = environment.getProperty("spring.datasource.url");
if(jdbcUrl == null){
String master = environment.getProperty("spring.datasource.dynamic.primary");
jdbcUrl = environment.getProperty("spring.datasource.dynamic.datasource."+master+".url");
}
if(jdbcUrl != null){
DbType dbType = JdbcUtils.getDbType(jdbcUrl);
DATABASE_TYPE = dbType.getDb();
}
else{
SqlSessionFactory sqlSessionFactory = getBean(SqlSessionFactory.class);
if(sqlSessionFactory != null){
DATABASE_TYPE = sqlSessionFactory.getConfiguration().getDatabaseId();
}
}
if(DATABASE_TYPE == null){
log.warn("无法识别数据库类型,请检查配置!");
}
return DATABASE_TYPE;
}
}

View File

@ -1,8 +1,6 @@
package com.diboot.core.util;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.env.Environment;
/**

View File

@ -8,6 +8,7 @@ import com.diboot.core.config.BaseConfig;
import com.diboot.core.entity.Dictionary;
import com.diboot.core.service.impl.DictionaryServiceImpl;
import com.diboot.core.util.BeanUtils;
import com.diboot.core.util.ContextHelper;
import com.diboot.core.util.V;
import com.diboot.core.vo.*;
import diboot.core.test.StartupApplication;
@ -233,6 +234,12 @@ public class BaseServiceTest {
public void testExist(){
boolean exists = dictionaryService.exists(Dictionary::getType, "GENDER");
Assert.assertTrue(exists);
}
@Test
public void testContextHelper(){
String database = ContextHelper.getDatabaseType();
System.out.println(database);
Assert.assertTrue(database.equals("mysql"));
}
}