本人比较懒,在做一些小的项目的时候,看到数据库有很多的表,然后要一一对应给其生成实体类,我就想能不能有一个简单的方式进行生成实体类

当初在网上查了一些资料,然后使用了一下Hibernate的正向生成的功能,发现Hibernate 生成的有很多是不需要的就想着自己写一个生成简单实体类

的功能,这个是基于javaPoet这个开源项目,使用它提供的方法进行整合然后写的,

使用的工具是MyEclipse和oracle话不多说上代码


package com.personal.tool;

import java.io.File;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Scanner;

import javax.lang.model.element.Modifier;
import javax.persistence.Column;
import javax.persistence.Entity;
import javax.persistence.Id;
import javax.persistence.Table;

import com.personal.factory.PersonalFactory;
import com.squareup.javapoet.AnnotationSpec;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeSpec;

@SuppressWarnings("all")
public class CreateClass { // 根据表名创建类

    /**
     * 根据数据库表名进行创建实体类
     * 
     * 自用修改需要修改的地方有: 1.修改获取连接的方法名 2.修改tableName所在的包路径
     * 
     * 
     * @param tableName
     * @param url
     */

    public static void createClass(Connection con, String tableName, String url) {// 传过来表名和完整的包路径
        PreparedStatement pre = null;
        ResultSet rs = null;
        try {
            String sql = "select * from " + tableName;// 根据表名组合成为sql语句
            pre = con.prepareStatement(sql);
            rs = pre.executeQuery();
            ResultSetMetaData res = rs.getMetaData();// 获取元数据
            DatabaseMetaData data = con.getMetaData();
            ResultSet primaryKeys = data.getPrimaryKeys(null, null, tableName);
            String key = null;
            while (primaryKeys.next()) {
                key = primaryKeys.getString("COLUMN_NAME");
                break;
            }
            if (key == null) {
                key = "1";
            }
            int count = res.getColumnCount();// 获取列数
            List<MethodSpec> methodlist = new ArrayList<MethodSpec>();
            List<FieldSpec> fieldList = new ArrayList<FieldSpec>();
            for (int i = 1; i <= count; i++) {
                String name = res.getColumnName(i).toLowerCase(); // 获取这列的名称
                int ctype = res.getColumnType(i); // 获取这列的数据类型
                if (!key.equalsIgnoreCase("1")) {
                    if (key.equalsIgnoreCase(name)) {
                        if (ctype == Types.VARCHAR) {
                            fieldList.add(field(String.class, name));
                            methodlist.add(getKeyMethod(String.class, name));
                            methodlist.add(setMethod(String.class, name));
                        } else if (ctype == Types.NUMERIC) {
                            if (res.getScale(i) > 0) {
                                fieldList.add(field(Double.class, name));
                                methodlist
                                        .add(getKeyMethod(Double.class, name));
                                methodlist.add(setMethod(Double.class, name));
                            } else {
                                fieldList.add(field(Integer.class, name));
                                methodlist
                                        .add(getKeyMethod(Integer.class, name));
                                methodlist.add(setMethod(Integer.class, name));
                            }
                        } else if (ctype == Types.DATE) {
                            fieldList.add(field(Date.class, name));
                            methodlist.add(getKeyMethod(Date.class, name));
                            methodlist.add(setMethod(Date.class, name));
                        }
                    } else {
                        if (ctype == Types.VARCHAR) {
                            fieldList.add(field(String.class, name));
                            methodlist.add(getMethod(String.class, name));
                            methodlist.add(setMethod(String.class, name));
                        } else if (ctype == Types.NUMERIC) {
                            if (res.getScale(i) > 0) {
                                fieldList.add(field(Double.class, name));
                                methodlist.add(getMethod(Double.class, name));
                                methodlist.add(setMethod(Double.class, name));
                            } else {
                                fieldList.add(field(Integer.class, name));
                                methodlist.add(getMethod(Integer.class, name));
                                methodlist.add(setMethod(Integer.class, name));
                            }
                        } else if (ctype == Types.DATE) {
                            fieldList.add(field(Date.class, name));
                            methodlist.add(getMethod(Date.class, name));
                            methodlist.add(setMethod(Date.class, name));
                        }
                    }
                } else {
                    if (ctype == Types.VARCHAR) {
                        fieldList.add(field(String.class, name));
                        methodlist.add(getMethod(String.class, name));
                        methodlist.add(setMethod(String.class, name));
                    } else if (ctype == Types.NUMERIC) {
                        if (res.getScale(i) > 0) {
                            fieldList.add(field(Double.class, name));
                            methodlist.add(getMethod(Double.class, name));
                            methodlist.add(setMethod(Double.class, name));
                        } else {
                            fieldList.add(field(Integer.class, name));
                            methodlist.add(getMethod(Integer.class, name));
                            methodlist.add(setMethod(Integer.class, name));
                        }
                    } else if (ctype == Types.DATE) {
                        fieldList.add(field(Date.class, name));
                        methodlist.add(getMethod(Date.class, name));
                        methodlist.add(setMethod(Date.class, name));
                    }
                }
            }
            savClass(tableName, methodlist, fieldList, url);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                if (rs != null) {
                    rs.close();
                }
                if (pre != null) {
                    pre.close();
                }
                if (con != null) {
                    con.close();
                }
            } catch (Exception e) {
                // TODO: handle exception
            }
        }
    }

    private static <T> List<T> getSelect(Connection con, String sql,
            String className, Object... objects) {// 参数为
        // sql语句
        // 类的完全限定名
        // 不确定个数的参数
        List<T> list = new ArrayList<T>();
        PreparedStatement pr = null;
        ResultSet rs = null;
        try {
            Class<T> clazz = (Class<T>) Class.forName(className);// 根据类的地址进行反射出来实体类
            Field[] fi = clazz.getDeclaredFields(); // 获取类里面的所有属性
            Method me = null;
            pr = con.prepareStatement(sql);
            insertInfo(pr, objects);
            rs = pr.executeQuery();

            ResultSetMetaData res = rs.getMetaData();// 获取元数据
            int count = res.getColumnCount();// 获取列数
            while (rs.next()) {
                T obs = clazz.newInstance();// 创建一个反射类的对象
                for (int i = 1; i <= count; i++) { // 循环对数据库查询出来的第一个数据进行封装
                    String name = res.getColumnName(i); // 获取这列的名称
                    int ctype = res.getColumnType(i); // 获取这一列的数据类型
                    for (Field field : fi) { // 遍历根据反射查询出来的类
                        if (name.equalsIgnoreCase(field.getName())) {// 判断这个数据库查询出来的属性名称是不是和实体类里面的名称相同
                            String firs = field.getName().substring(0, 1)// 如果是则进行截取字符串把这个名称第一个字符进行截取然后转换为大写
                                    .toUpperCase();
                            String meth = "set" + firs// 然后拼接字符串把这个属性的set方法进行拼接出来组成为(set+大写属性首字符+属性后的名称=set方法)
                                    + field.getName().substring(1);
                            if (ctype == Types.INTEGER) { // 判断列的数据类型如果是则通过反射的set方法进行添加到这个实体类中
                                me = clazz.getMethod(meth, Integer.class);
                                me.invoke(obs, rs.getInt(i));
                                break;
                            } else if (ctype == Types.VARCHAR) {
                                me = clazz.getMethod(meth, String.class);
                                me.invoke(obs, rs.getString(i));
                                break;
                            } else if (ctype == Types.DATE) {
                                me = clazz.getMethod(meth, Date.class);
                                me.invoke(obs, rs.getDate(i));
                                break;
                            } else if (ctype == Types.NUMERIC) {
                                field.getGenericType();// 获取到这个属性的类型然后判断这个类型是什么然后根据类型创建方法
                                rs.getBigDecimal(i).intValue();
                                me = clazz.getMethod(meth, BigDecimal.class);
                                me.invoke(obs, rs.getBigDecimal(i));
                                break;
                            }
                        }
                    }
                }
                list.add(obs);// 所有属性的执行方法执行完毕的时候把这个实体类添加到List里面
            }
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
            try {
                if (rs != null) {
                    rs.close();
                }
                if (pr != null) {
                    pr.close();
                }
                if (con != null) {
                    con.close();
                }
            } catch (Exception e) {
                // TODO: handle exception
            }
        }
        return list;// 返回list
    }

    private static void insertInfo(PreparedStatement pre, Object... objects) {
        if (null == pre || null == objects)
            return;
        for (int i = 0; i < objects.length; i++) {
            try {
                pre.setObject(i + 1, objects[i]);
            } catch (SQLException e) {
                e.printStackTrace();
            }
        }
    }

    public static List<String> listTableName(Connection con) { // 查询出数据库这个用户的所有的表然后可以根据表进行循环生成对应的实体类
        List<OracleTableName> list = getSelect(con,
                "select table_name from user_tables",
                OracleTableName.class.getName());
        List<String> tableName = new ArrayList<String>();
        for (OracleTableName name : list) {
            tableName.add(name.getTable_name());
        }
        return tableName;
    }

    @SuppressWarnings("all")
    private static MethodSpec getMethod(Class clazz, String name) { // 设置属性的get方法
        String firs = name.substring(0, 1)// 如果是则进行截取字符串把这个名称第一个字符进行截取然后转换为大写
                .toUpperCase();
        String getName = "get" + firs + name.substring(1);
        MethodSpec getMethod = MethodSpec
                .methodBuilder(getName)
                .addModifiers(Modifier.PUBLIC)
                .returns(clazz)
                .addAnnotation(
                        AnnotationSpec.builder(Column.class)
                                .addMember("name", "$S", name)// 生成注解的属性名
                                .addMember("nullable", "$L", "false").build())// 生成注解这个字段是否可以为空
                .addStatement("return " + name).build();
        return getMethod;
    }

    private static MethodSpec getKeyMethod(Class clazz, String name) { // 设置属性的get方法
        String firs = name.substring(0, 1)// 如果是则进行截取字符串把这个名称第一个字符进行截取然后转换为大写
                .toUpperCase();
        String getName = "get" + firs + name.substring(1);
        MethodSpec getMethod = MethodSpec
                .methodBuilder(getName)
                .addModifiers(Modifier.PUBLIC)
                .returns(clazz)
                .addAnnotation(
                        AnnotationSpec.builder(Id.class)
                                .addMember("", "$S", name).build())

                .addStatement("return " + name).build();
        return getMethod;
    }

    private static MethodSpec setMethod(Class clazz, String name) {// 设置属性的set方法
        String firs = name.substring(0, 1)// 如果是则进行截取字符串把这个名称第一个字符进行截取然后转换为大写
                .toUpperCase();
        String setName = "set" + firs// 然后拼接字符串把这个属性的set方法进行拼接出来组成为(set+大写属性首字符+属性后的名称=set方法)
                + name.substring(1);
        String filds = "this." + name + "=" + name;
        MethodSpec setMethod = MethodSpec.methodBuilder(setName)
                .addModifiers(Modifier.PUBLIC).returns(void.class)
                .addParameter(clazz, name).addStatement(filds).build();
        return setMethod;
    }

    private static FieldSpec field(Class clazz, String name) {// 设置属性方法
        FieldSpec feid = FieldSpec.builder(clazz, name, Modifier.PRIVATE)
                .build();
        return feid;
    }

    private static void savClass(String tableName, List<MethodSpec> methodList, // 生成实体类方法
            List<FieldSpec> fieldList, String url) throws IOException {
        String tabName = tableName.substring(0, 1).toUpperCase();
        tabName = tabName + tableName.substring(1).toLowerCase();
        // 根据类名创建
        TypeSpec helloword = TypeSpec
                .classBuilder(tabName)
                // 创建类的修饰符和添加类的方法
                .addModifiers(Modifier.PUBLIC)

                .addMethods(methodList)
                .addAnnotation(
                        AnnotationSpec.builder(Entity.class)
                                .addMember("name", "$S", tabName).build())
                .addAnnotation(
                        AnnotationSpec.builder(Table.class)
                                .addMember("name", "$S", tabName).build())
                .addFields(fieldList).build();
        int index = url.lastIndexOf(".");

        TypeSpec dao = TypeSpec.interfaceBuilder(tabName + "Dao").build();
        JavaFile fil = JavaFile.builder(url, helloword)
        // 导入所需要的包
                .build();
        JavaFile daos = JavaFile.builder(url.substring(0, index) + ".Dao", dao)
                .build();
        String path = System.getProperty("user.dir") + "/src";
        fil.writeTo(new File(path));
        daos.writeTo(new File(path));
    }

    public static boolean setClss(Connection con, List<String> tableName,
            String packageName) { // 选择要生成的表
        Scanner sc = new Scanner(System.in);
        int num = 0;
        for (String str : tableName) {
            System.out.print(num++ + "-->" + str + "\t");
            if (num % 1 == 0) {
                System.out.println();
            }
        }
        System.out.println();
        List<Integer> tableId = new ArrayList<Integer>();
        boolean flag = false;
        while (!flag) {
            System.out.println("请选择要生成的表");
            int id = sc.nextInt();
            tableId.add(id);
            System.out.println("是否继续选择Y/S");
            String ys = sc.next();
            if (ys.equalsIgnoreCase("s")) {
                break;
            }
        }
        List<String> tab = new ArrayList<String>();
        for (Integer in : tableId) {
            if (tableId.size() == 0 || tableId.size() > tableName.size()) {
                throw new IndexOutOfBoundsException("下标越界");
            }
            tab.add(tableName.get(in));
        }
        int nums = 0;
        for (String str : tab) {
            CreateClass.createClass(con, str, packageName);
            nums++;
        }
        return nums == tab.size() ? true : false;
    }

    public static List<sequence> getSequence() {// 查询出来所有的序列
        List<sequence> select = getSelect(PersonalFactory.getConnection(),
                "select SEQUENCE_NAME from user_sequences",
                sequence.class.getName(), null);
        return select;
    }

}

/**
 * 使用这个所需要的几个内部类, 1,序列类 2,表类
 * 
 * */
class sequence {
    private String sequence_name;

    public String getSequence_name() {
        return sequence_name;
    }

    public void setSequence_name(String sequence_name) {
        this.sequence_name = sequence_name;
    }

}

class OracleTableName {
    private String table_name;

    public String getTable_name() {
        return table_name;
    }

    public void setTable_name(String table_name) {
        this.table_name = table_name;
    }

}



这是完整代码还有很多不足,望大家指出共同商讨