Implementing User-Defined Functions in Flink

Posted by Chai's Blog on December 1, 2019

Given their expressive power, user-defined functions are used extensively in SQL. Flink likewise supports user-defined functions in its Table/SQL API, and according to the roadmap presented at Flink Forward Asia 2019, upcoming Flink releases will add Python UDFs as well as Hive-compatible UDFs. This post gives a brief overview of Flink's UDF support and how it is implemented.

Supported Types of User-Defined Functions

Flink supports user-defined functions of type UDF, UDTF, and UDAF. They are all implemented by extending a subclass of UserDefinedFunction and defining the required methods; Flink then uses code generation to produce helper classes that call those methods at runtime.

Function type   Purpose                                               Implementing class
UDF             Transforms a field: one record in, one record out     ScalarFunction
UDTF            One record in, one or more records out                TableFunction
UDAF            One or more records in, one record out                AggregateFunction

Implementing User-Defined Functions

UDF Implementation

In Flink, scalar UDFs are implemented with ScalarFunction. An application only needs to extend ScalarFunction and define an eval method.

Using it is straightforward. Under the hood, the supporting code is produced by code generation and is responsible for calling eval: in the ProcessOperator, the wrapped function is this generated class, and its processElement method invokes our eval method to perform the conversion.

A simple UDF looks like this:

public class ToUpperCase extends ScalarFunction {

    public String eval(String s) {

        return s.toUpperCase();
    }
}
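
To use the function from SQL, it must first be registered with the table environment. A minimal usage sketch follows (Flink 1.9-style Java Table API assumed; the table name words and its field word are illustrative and not part of the original job):

// Sketch only: assumes a table "words" with a String field "word" has been registered.
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
StreamTableEnvironment tableEnv = StreamTableEnvironment.create(env);

tableEnv.registerFunction("toUpper", new ToUpperCase());   // name used in SQL

Table result = tableEnv.sqlQuery(
        "SELECT toUpper(word) AS upper_word FROM words");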

The call stack is shown below:

(Figure: UDF call stack)

The source of the function produced by code generation is as follows:

public class DataStreamCalcRule$19 extends org.apache.flink.streaming.api.functions.ProcessFunction {

    final com.cyq.learning.udf.function.ToUpperCase function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd;


    final org.apache.flink.types.Row out =
            new org.apache.flink.types.Row(4);


    public DataStreamCalcRule$19() throws Exception {

        function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd = (com.cyq.learning.udf.function.ToUpperCase)
                org.apache.flink.table.utils.EncodingUtils.decodeStringToObject(
                        "rO0ABXNyACljb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5Ub1VwcGVyQ2FzZViZ_w3c2xj-AgAAeHIAL29yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLlNjYWxhckZ1bmN0aW9uygzW306qntsCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvboP22EvVTaRLAgAAeHA",
                        org.apache.flink.table.functions.UserDefinedFunction.class);

    }

    @Override
    public void open(org.apache.flink.configuration.Configuration parameters) throws Exception {

        function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.open(new org.apache.flink.table.functions.FunctionContext(getRuntimeContext()));


    }

    @Override
    public void processElement(Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context ctx, org.apache.flink.util.Collector c) throws Exception {
        org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) _in1;

        boolean isNull$18 = (java.lang.Long) in1.getField(2) == null;
        long result$17;
        if (isNull$18) {
            result$17 = -1L;
        } else {
            result$17 = (java.lang.Long) in1.getField(2);
        }


        boolean isNull$11 = (java.lang.String) in1.getField(0) == null;
        java.lang.String result$10;
        if (isNull$11) {
            result$10 = "";
        } else {
            result$10 = (java.lang.String) (java.lang.String) in1.getField(0);
        }


        boolean isNull$13 = (java.lang.String) in1.getField(1) == null;
        java.lang.String result$12;
        if (isNull$13) {
            result$12 = "";
        } else {
            result$12 = (java.lang.String) (java.lang.String) in1.getField(1);
        }


        if (isNull$11) {
            out.setField(0, null);
        } else {
            out.setField(0, result$10);
        }


        if (isNull$13) {
            out.setField(1, null);
        } else {
            out.setField(1, result$12);
        }


        java.lang.String result$14 = function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.eval(
                isNull$13 ? null : (java.lang.String) result$12);


        boolean isNull$16 = result$14 == null;
        java.lang.String result$15;
        if (isNull$16) {
            result$15 = "";
        } else {
            result$15 = (java.lang.String) result$14;
        }


        if (isNull$16) {
            out.setField(2, null);
        } else {
            out.setField(2, result$15);
        }


        if (isNull$18) {
            out.setField(3, null);
        } else {
            out.setField(3, result$17);
        }

        c.collect(out);

    }

    @Override
    public void close() throws Exception {

        function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.close();


    }
}

UDTF Implementation

A UDTF in Flink is implemented by extending TableFunction. The key to turning one input row into multiple output rows is TableFunction's collect method: it forwards a record to the downstream operator through the wrapped collector, so when a record arrives, eval can split it and emit each piece separately, turning one input record into several output records.

A simple UDTF implementation:

public class SplitFun extends TableFunction<Tuple2<String, Integer>> {

    private String separator = ",";
    public SplitFun(String separator) {
        this.separator = separator;
    }
    public void eval(String str) {
        for (String s : str.split(separator)) {
            collect(new Tuple2<String, Integer>(s, s.length()));
        }
    }
}
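
As a usage sketch, the function can be registered and each input row joined with its split results through LATERAL TABLE (the table sentences and its field content are illustrative assumptions):

// Sketch only: tableEnv is a StreamTableEnvironment as in the UDF example above,
// and a table "sentences" with a String field "content" is assumed to be registered.
tableEnv.registerFunction("split", new SplitFun(","));

// every row of "sentences" is joined with the rows SplitFun emits via collect()
Table result = tableEnv.sqlQuery(
        "SELECT content, word, length " +
        "FROM sentences, LATERAL TABLE(split(content)) AS T(word, length)");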

The call stack is shown below:

(Figure: UDTF call stack)

Code generation produces two classes here (TableFunctionCollector$42 and DataStreamCorrelateRule$28). TableFunctionCollector is the helper collector that forwards records to the downstream operator, while DataStreamCorrelateRule invokes the TableFunction (SplitFun) from its processElement method.

  • The source of DataStreamCorrelateRule$28:

    public class DataStreamCorrelateRule$28 extends org.apache.flink.streaming.api.functions.ProcessFunction {
    
        final org.apache.flink.table.runtime.TableFunctionCollector instance_org$apache$flink$table$runtime$TableFunctionCollector$22;
    
        final com.cyq.learning.udf.function.SplitFun function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d;
    
    
        final org.apache.flink.types.Row out =
                new org.apache.flink.types.Row(6);
    
    
        public DataStreamCorrelateRule$28() throws Exception {
    
            function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d = (com.cyq.learning.udf.function.SplitFun)
                    org.apache.flink.table.utils.EncodingUtils.decodeStringToObject(
                            "rO0ABXNyACZjb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5TcGxpdEZ1bl5QHqGTdoFOAgABTAAJc2VwYXJhdG9ydAASTGphdmEvbGFuZy9TdHJpbmc7eHIALm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLlRhYmxlRnVuY3Rpb27qA9IDIKYg-gIAAUwACWNvbGxlY3RvcnQAIUxvcmcvYXBhY2hlL2ZsaW5rL3V0aWwvQ29sbGVjdG9yO3hyADRvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5Vc2VyRGVmaW5lZEZ1bmN0aW9ug_bYS9VNpEsCAAB4cHB0AAFA",
                            org.apache.flink.table.functions.UserDefinedFunction.class);
    
    
        }
    
    
        public DataStreamCorrelateRule$28(org.apache.flink.table.runtime.TableFunctionCollector arg0) throws Exception {
            this();
            instance_org$apache$flink$table$runtime$TableFunctionCollector$22 = arg0;
    
        }
    
    
        @Override
        public void open(org.apache.flink.configuration.Configuration parameters) throws Exception {
    
            function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.open(new org.apache.flink.table.functions.FunctionContext(getRuntimeContext()));
    
    
        }
    
        @Override
        public void processElement(Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context ctx, org.apache.flink.util.Collector c) throws Exception {
            org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) _in1;
    
            boolean isNull$15 = (java.lang.Long) in1.getField(2) == null;
            long result$14;
            if (isNull$15) {
                result$14 = -1L;
            } else {
                result$14 = (java.lang.Long) in1.getField(2);
            }
    
    
            boolean isNull$11 = (java.lang.String) in1.getField(0) == null;
            java.lang.String result$10;
            if (isNull$11) {
                result$10 = "";
            } else {
                result$10 = (java.lang.String) (java.lang.String) in1.getField(0);
            }
    
    
            boolean isNull$13 = (java.lang.String) in1.getField(1) == null;
            java.lang.String result$12;
            if (isNull$13) {
                result$12 = "";
            } else {
                result$12 = (java.lang.String) (java.lang.String) in1.getField(1);
            }
    
    
            boolean isNull$17 = (java.lang.Long) in1.getField(3) == null;
            long result$16;
            if (isNull$17) {
                result$16 = -1L;
            } else {
                result$16 = (long) org.apache.calcite.runtime.SqlFunctions.toLong((java.lang.Long) in1.getField(3));
            }
    
    
            function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.setCollector(instance_org$apache$flink$table$runtime$TableFunctionCollector$22);
    
    
            java.lang.String result$25;
            boolean isNull$26;
            if (false) {
                result$25 = "";
                isNull$26 = true;
            } else {
    
                boolean isNull$24 = (java.lang.String) in1.getField(1) == null;
                java.lang.String result$23;
                if (isNull$24) {
                    result$23 = "";
                } else {
                    result$23 = (java.lang.String) (java.lang.String) in1.getField(1);
                }
    
                result$25 = result$23;
                isNull$26 = isNull$24;
            }
    
            function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.eval(isNull$26 ? null : (java.lang.String) result$25);
    
    
            boolean hasOutput = instance_org$apache$flink$table$runtime$TableFunctionCollector$22.isCollected();
            if (!hasOutput) {
    
    
                if (isNull$11) {
                    out.setField(0, null);
                } else {
                    out.setField(0, result$10);
                }
    
    
                if (isNull$13) {
                    out.setField(1, null);
                } else {
                    out.setField(1, result$12);
                }
    
    
                if (isNull$15) {
                    out.setField(2, null);
                } else {
                    out.setField(2, result$14);
                }
    
    
                java.lang.Long result$27;
                if (isNull$17) {
                    result$27 = null;
                } else {
                    result$27 = result$16;
                }
    
                if (isNull$17) {
                    out.setField(3, null);
                } else {
                    out.setField(3, result$27);
                }
    
    
                if (true) {
                    out.setField(4, null);
                } else {
                    out.setField(4, "");
                }
    
    
                if (true) {
                    out.setField(5, null);
                } else {
                    out.setField(5, -1);
                }
    
                c.collect(out);
            }
    
        }
    
        @Override
        public void close() throws Exception {
    
            function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.close();
    
    
        }
    }
    
  • The source of TableFunctionCollector$42:

    public class TableFunctionCollector$42 extends org.apache.flink.table.runtime.TableFunctionCollector {
    
    
        final org.apache.flink.types.Row out =
                new org.apache.flink.types.Row(6);
    
    
        public TableFunctionCollector$42() throws Exception {
    
    
        }
    
        @Override
        public void open(org.apache.flink.configuration.Configuration parameters) throws Exception {
    
    
        }
    
        @Override
        public void collect(Object record) throws Exception {
            super.collect(record);
            org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) getInput();
            org.apache.flink.api.java.tuple.Tuple2 in2 = (org.apache.flink.api.java.tuple.Tuple2) record;
    
            boolean isNull$34 = (java.lang.Long) in1.getField(2) == null;
            long result$33;
            if (isNull$34) {
                result$33 = -1L;
            } else {
                result$33 = (java.lang.Long) in1.getField(2);
            }
    
    
            boolean isNull$30 = (java.lang.String) in1.getField(0) == null;
            java.lang.String result$29;
            if (isNull$30) {
                result$29 = "";
            } else {
                result$29 = (java.lang.String) (java.lang.String) in1.getField(0);
            }
    
    
            boolean isNull$32 = (java.lang.String) in1.getField(1) == null;
            java.lang.String result$31;
            if (isNull$32) {
                result$31 = "";
            } else {
                result$31 = (java.lang.String) (java.lang.String) in1.getField(1);
            }
    
    
            boolean isNull$36 = (java.lang.Long) in1.getField(3) == null;
            long result$35;
            if (isNull$36) {
                result$35 = -1L;
            } else {
                result$35 = (long) org.apache.calcite.runtime.SqlFunctions.toLong((java.lang.Long) in1.getField(3));
            }
    
    
            if (isNull$30) {
                out.setField(0, null);
            } else {
                out.setField(0, result$29);
            }
    
    
            if (isNull$32) {
                out.setField(1, null);
            } else {
                out.setField(1, result$31);
            }
    
    
            if (isNull$34) {
                out.setField(2, null);
            } else {
                out.setField(2, result$33);
            }
    
    
            java.lang.Long result$41;
            if (isNull$36) {
                result$41 = null;
            } else {
                result$41 = result$35;
            }
    
            if (isNull$36) {
                out.setField(3, null);
            } else {
                out.setField(3, result$41);
            }
    
    
            boolean isNull$38 = (java.lang.String) in2.f0 == null;
            java.lang.String result$37;
            if (isNull$38) {
                result$37 = "";
            } else {
                result$37 = (java.lang.String) (java.lang.String) in2.f0;
            }
    
            if (isNull$38) {
                out.setField(4, null);
            } else {
                out.setField(4, result$37);
            }
    
    
            boolean isNull$40 = (java.lang.Integer) in2.f1 == null;
            int result$39;
            if (isNull$40) {
                result$39 = -1;
            } else {
                result$39 = (java.lang.Integer) in2.f1;
            }
    
            if (isNull$40) {
                out.setField(5, null);
            } else {
                out.setField(5, result$39);
            }
    
            getCollector().collect(out);
    
        }
    
        @Override
        public void close() throws Exception {
    
    
        }
    }
    

UDAF Implementation

A UDAF in Flink is implemented by extending AggregateFunction and implementing the aggregation logic. A UDAF requires more methods than the other function types, and the exact set of methods differs from scenario to scenario, so its implementation is more involved; see UserDefinedFunction for details.

A simple UDAF might look like this:

public class AggFunctionTest {
    public static void main(String[] args) throws Exception {
        StreamTableEnvironment streamTableEnvironment = EnvironmentUtil.getStreamTableEnvironment();

        TableSchema tableSchema = new TableSchema(new String[]{"product_id", "number", "price", "time"},
                new TypeInformation[]{Types.STRING, Types.LONG, Types.INT, Types.SQL_TIMESTAMP});

        Map<String, String> kakfaProps = new HashMap<>();
        kakfaProps.put("bootstrap.servers", BROKER_SERVERS);

        FlinkKafkaConsumerBase flinkKafkaConsumerBase = SourceUtil.createKafkaSourceStream(TOPIC, kakfaProps, tableSchema);

        DataStream<Row> stream = streamTableEnvironment.execEnv().addSource(flinkKafkaConsumerBase);

        streamTableEnvironment.registerFunction("wAvg", new WeightedAvg());

        Table streamTable = streamTableEnvironment.fromDataStream(stream, "product_id,number,price,rowtime.rowtime");

        streamTableEnvironment.registerTable("product", streamTable);
        Table result = streamTableEnvironment.sqlQuery("SELECT product_id, wAvg(number,price) as avgPrice FROM product group by product_id");


        final TableSchema tableSchemaResult = new TableSchema(new String[]{"product_id", "avg"},
                new TypeInformation[]{Types.STRING, Types.LONG});
        final TypeInformation<Row> typeInfoResult = tableSchemaResult.toRowType();

        // toRetractStream is needed here because the result Table is not an append-only table
        DataStream ds = streamTableEnvironment.toRetractStream(result, typeInfoResult);
        ds.print();
        streamTableEnvironment.execEnv().execute("Flink Agg Function Test");


    }
}

The accumulator class and the aggregate function are defined as follows:

public class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}
public class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }
    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        } else {
            return acc.sum / acc.count;
        }
    }
    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }
    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }
    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

Note: in AggFunctionTest, the SQL we run is a global (non-windowed) aggregation. When converting the result to a stream for output, we therefore use toRetractStream, i.e. retract-stream mode: because the aggregation is global, its result keeps changing as new records arrive, so an append stream cannot be used. For background on toRetractStream, see the documentation on table-to-stream conversion.
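
For intuition, a retract stream wraps every output row in a boolean flag: true marks a new or updated result, false retracts a result emitted earlier. A minimal sketch of what the conversion above produces (the sample values are purely illustrative):

// toRetractStream returns DataStream<Tuple2<Boolean, Row>>:
//   f0 == true  -> add this row (new or updated aggregate result)
//   f0 == false -> retract a row that was emitted before
DataStream<Tuple2<Boolean, Row>> retractStream =
        streamTableEnvironment.toRetractStream(result, typeInfoResult);

// if the average for product "p1" changes from 10 to 12, the printed output
// looks roughly like: (false,p1,10) followed by (true,p1,12)
retractStream.print();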

The aggregation is executed by GroupAggProcessFunction's processElement method, whose core logic is:

override def processElement(
    inputC: CRow,
    ctx: ProcessFunction[CRow, CRow]#Context,
    out: Collector[CRow]): Unit = {

  val currentTime = ctx.timerService().currentProcessingTime()
  // register state-cleanup timer
  processCleanupTimer(ctx, currentTime)

  val input = inputC.row

  // get accumulators and input counter
  var accumulators = state.value()
  var inputCnt = cntState.value()

  if (null == accumulators) {
    // Don't create a new accumulator for a retraction message. This
    // might happen if the retraction message is the first message for the
    // key or after a state clean up.
    if (!inputC.change) {
      return
    }
    // first accumulate message
    firstRow = true
    accumulators = function.createAccumulators()
  } else {
    firstRow = false
  }

  if (null == inputCnt) {
    inputCnt = 0L
  }

  // Set group keys value to the final output
  function.setForwardedFields(input, newRow.row)
  function.setForwardedFields(input, prevRow.row)

  // Set previous aggregate result to the prevRow
  function.setAggregationResults(accumulators, prevRow.row)

  // update aggregate result and set to the newRow
  if (inputC.change) {
    inputCnt += 1
    // accumulate input
    function.accumulate(accumulators, input)
    function.setAggregationResults(accumulators, newRow.row)
  } else {
    inputCnt -= 1
    // retract input
    function.retract(accumulators, input)
    function.setAggregationResults(accumulators, newRow.row)
  }

  if (inputCnt != 0) {
    // we aggregated at least one record for this key

    // update the state
    state.update(accumulators)
    cntState.update(inputCnt)

    // if this was not the first row
    if (!firstRow) {
      if (prevRow.row.equals(newRow.row) && !stateCleaningEnabled) {
        // newRow is the same as before and state cleaning is not enabled.
        // We emit nothing
        // If state cleaning is enabled, we have to emit messages to prevent too early
        // state eviction of downstream operators.
        return
      } else {
        // retract previous result
        if (generateRetraction) {
          out.collect(prevRow)
        }
      }
    }
    // emit the new result
    out.collect(newRow)

  } else {
    // we retracted the last record for this key
    // sent out a delete message
    out.collect(prevRow)
    // and clear all state
    state.clear()
    cntState.clear()
  }
}

As the code above shows, the actual aggregation work, updating state, deciding what to emit, and so on, is delegated to function, which is again produced by code generation. Its source is as follows:

public final class NonWindowedAggregationHelper$17 extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {


   final com.cyq.learning.udf.function.WeightedAvg function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881;


       private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<com.cyq.learning.udf.function.WeightedAvgAccum> accIt0 =
               new org.apache.flink.table.runtime.aggregate.SingleElementIterable<com.cyq.learning.udf.function.WeightedAvgAccum>();

      public NonWindowedAggregationHelper$17() throws Exception {

           function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881 = (com.cyq.learning.udf.function.WeightedAvg)
                   org.apache.flink.table.utils.EncodingUtils.decodeStringToObject(
                           "rO0ABXNyACljb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5XZWlnaHRlZEF2Z7o-ESHvqA7TAgAAeHIAMm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLkFnZ3JlZ2F0ZUZ1bmN0aW9uYPYkANq5VRgCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvboP22EvVTaRLAgAAeHA",
                           org.apache.flink.table.functions.UserDefinedFunction.class);

   }


   public final void open(
              org.apache.flink.api.common.functions.RuntimeContext ctx) throws Exception {

          function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.open(new org.apache.flink.table.functions.FunctionContext(ctx));

       }


   public final void setAggregationResults(
            org.apache.flink.types.Row accs,
              org.apache.flink.types.Row output) throws Exception {

          org.apache.flink.table.functions.AggregateFunction baseClass0 =
               (org.apache.flink.table.functions.AggregateFunction) function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881;
           com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) accs.getField(0);

           output.setField(
               1,
               baseClass0.getValue(acc0));


       }

   public final void accumulate(
           org.apache.flink.types.Row accs,
            org.apache.flink.types.Row input) throws Exception {

           com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) accs.getField(0);

       function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.accumulate(acc0
               , (java.lang.Long) input.getField(1), (java.lang.Integer) input.getField(2));

    }


       public final void retract(
               org.apache.flink.types.Row accs,
               org.apache.flink.types.Row input) throws Exception {
   }

    public final org.apache.flink.types.Row createAccumulators() throws Exception {

           org.apache.flink.types.Row accs =
                   new org.apache.flink.types.Row(1);
           com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.createAccumulator();
       accs.setField(
               0,
                acc0);
           return accs;
   }

   public final void setForwardedFields(
               org.apache.flink.types.Row input,
               org.apache.flink.types.Row output) {

           output.setField(
                   0,
                   input.getField(0));
       }

       public final org.apache.flink.types.Row createOutputRow() {
       return new org.apache.flink.types.Row(2);
   }


       public final org.apache.flink.types.Row mergeAccumulatorsPair(
               org.apache.flink.types.Row a,
               org.apache.flink.types.Row b) {

        return a;

    }

       public final void resetAccumulator(
               org.apache.flink.types.Row accs) throws Exception {
       }

       public final void cleanup() throws Exception {

       }

       public final void close() throws Exception {

           function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.close();

       }
   }

The full source for the examples above is available in Learning.

With the analysis and generated code above, readers should have a rough picture of how user-defined functions are executed. The hardest part to follow is the code generation machinery itself. Because code generation brings both performance benefits and flexibility to SQL execution, the technique is used extensively in both Spark and Flink; I plan to analyze it in a follow-up post.