SparkSQL用UDAF实现Bitmap函数

共 4762字,需浏览 10分钟

 ·

2021-04-27 10:47

创建测试表

使用 Phoenix 在 HBase 中创建测试表,bitmap 字段使用 VARBINARY 类型。

-- Test table: one row per date; dist_mem stores the serialized bitmap bytes.
-- Rows are salted across 6 buckets.
CREATE TABLE IF NOT EXISTS test_binary (
    date     VARCHAR NOT NULL,
    dist_mem VARBINARY
    CONSTRAINT test_binary_pk PRIMARY KEY (date)
) SALT_BUCKETS=6;

创建完成后使用RoaringBitmap序列化数据存入数据库:

实现自定义聚合函数bitmap

import org.apache.spark.sql.Row;import org.apache.spark.sql.expressions.MutableAggregationBuffer;import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;import org.apache.spark.sql.types.DataType;import org.apache.spark.sql.types.DataTypes;import org.apache.spark.sql.types.StructField;import org.apache.spark.sql.types.StructType;import org.roaringbitmap.RoaringBitmap; import java.io.*;import java.util.ArrayList;import java.util.List; /** * 实现自定义聚合函数Bitmap */public class UdafBitMap extends UserDefinedAggregateFunction {    @Override    public StructType inputSchema() {        List<StructField> structFields = new ArrayList<>();        structFields.add(DataTypes.createStructField("field", DataTypes.BinaryType, true));        return DataTypes.createStructType(structFields);    }     @Override    public StructType bufferSchema() {        List<StructField> structFields = new ArrayList<>();        structFields.add(DataTypes.createStructField("field", DataTypes.BinaryType, true));        return DataTypes.createStructType(structFields);    }     @Override    public DataType dataType() {        return DataTypes.LongType;    }     @Override    public boolean deterministic() {        //是否强制每次执行的结果相同        return false;    }     @Override    public void initialize(MutableAggregationBuffer buffer) {        //初始化        buffer.update(0, null);    }     @Override    public void update(MutableAggregationBuffer buffer, Row input) {        // 相同的executor间的数据合并        // 1. 输入为空直接返回不更新        Object in = input.get(0);        if(in == null){            return ;        }        // 2. 源为空则直接更新值为输入        byte[] inBytes = (byte[]) in;        Object out = buffer.get(0);        if(out == null){            buffer.update(0, inBytes);            return ;        }        // 3. 
源和输入都不为空使用bitmap去重合并        byte[] outBytes = (byte[]) out;        byte[] result = outBytes;        RoaringBitmap outRR = new RoaringBitmap();        RoaringBitmap inRR = new RoaringBitmap();        try {            outRR.deserialize(new DataInputStream(new ByteArrayInputStream(outBytes)));            inRR.deserialize(new DataInputStream(new ByteArrayInputStream(inBytes)));            outRR.or(inRR);            ByteArrayOutputStream bos = new ByteArrayOutputStream();            outRR.serialize(new DataOutputStream(bos));            result = bos.toByteArray();        } catch (IOException e) {            e.printStackTrace();        }        buffer.update(0, result);    }     @Override    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {        //不同excutor间的数据合并        update(buffer1, buffer2);    }     @Override    public Object evaluate(Row buffer) {        //根据Buffer计算结果        long r = 0l;        Object val = buffer.get(0);        if (val != null) {            RoaringBitmap rr = new RoaringBitmap();            try {                rr.deserialize(new DataInputStream(new ByteArrayInputStream((byte[]) val)));                r = rr.getLongCardinality();            } catch (IOException e) {                e.printStackTrace();            }        }        return r;    }}

调用示例

 /**     * 使用自定义函数解析bitmap     *     * @param sparkSession     * @return     */    private static void udafBitmap(SparkSession sparkSession) {        try {            Properties prop = PropUtil.loadProp(DB_PHOENIX_CONF_FILE);            // JDBC连接属性            Properties connProp = new Properties();            connProp.put("driver", prop.getProperty(DB_PHOENIX_DRIVER));            connProp.put("user", prop.getProperty(DB_PHOENIX_USER));            connProp.put("password", prop.getProperty(DB_PHOENIX_PASS));            connProp.put("fetchsize", prop.getProperty(DB_PHOENIX_FETCHSIZE));            // 注册自定义聚合函数            sparkSession.udf().register("bitmap",new UdafBitMap());            sparkSession                    .read()                    .jdbc(prop.getProperty(DB_PHOENIX_URL), "test_binary", connProp)                    // sql中必须使用global_temp.表名,否则找不到                    .createOrReplaceGlobalTempView("test_binary");            //sparkSession.sql("select YEAR(TO_DATE(date)) year,bitmap(dist_mem) memNum from global_temp.test_binary group by YEAR(TO_DATE(date))").show();            sparkSession.sql("select date,bitmap(dist_mem) memNum from global_temp.test_binary group by date").show();        } catch (Exception e) {            e.printStackTrace();        }    }

结果:

浏览 26
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报