@hzhu212
Created May 30, 2023 09:29
Hive UDXF
package com.example.udf;
import com.google.common.base.Throwables;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.DeflaterOutputStream;
@Description(
name = "compress",
value = "_FUNC_(BINARY) -> BINARY, uses Java Deflater",
extended = "Example:\n > SELECT base64(_FUNC_(encode('hello', 'ASCII'))) -> 'eJzLSM3JyQcABiwCFQ=='"
)
public class DeflateUDF extends GenericUDF {
static final Logger logger = LoggerFactory.getLogger(DeflateUDF.class.getName());
private transient BinaryObjectInspector binOI;
private final transient ByteArrayOutputStream bos = new ByteArrayOutputStream(10 * 1024);
private final transient BytesWritable result = new BytesWritable();
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length != 1) {
throw new UDFArgumentLengthException(String.format("The function accepts exactly one binary type argument, but got %s.", arguments.length));
} else if (!(arguments[0] instanceof PrimitiveObjectInspector) || ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.BINARY) {
throw new UDFArgumentTypeException(0, String.format("The function accepts exactly one binary type argument, but got %s.", arguments[0].getTypeName()));
}
binOI = ((BinaryObjectInspector) arguments[0]);
return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
}
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
if (arguments == null || arguments.length < 1) {
throw new HiveException("arguments is empty");
}
final Object arg0 = arguments[0].get();
if (arg0 == null) {
logger.warn("argument is null");
return null;
}
bos.reset();
final byte[] bin = binOI.getPrimitiveJavaObject(arg0);
// try-with-resources closes the stream and releases the native Deflater even if writing fails
try (DeflaterOutputStream zip = new DeflaterOutputStream(bos)) {
zip.write(bin);
zip.finish();
} catch (IOException e) {
logger.error("failed to compress data (size={}). error: {}", bin.length, Throwables.getStackTraceAsString(e));
return null;
}
result.setSize(bos.size());
result.set(bos.toByteArray(), 0, bos.size());
return result;
}
@Override
public String getDisplayString(String[] children) {
return "Usage:" + this.getClass().getName() + "(" + children[0] + ")";
}
@Override
public void close() throws IOException {
super.close();
bos.close();
}
}
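A minimal HiveQL usage sketch for the compress UDF above; the jar path and the chosen function name are placeholders, not part of this gist.
ADD JAR /path/to/hive-udxf.jar;
CREATE TEMPORARY FUNCTION compress AS 'com.example.udf.DeflateUDF';
SELECT base64(compress(encode('hello', 'ASCII')));  -- 'eJzLSM3JyQcABiwCFQ=='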
package com.example.udf;
import com.google.common.base.Throwables;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentLengthException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.BytesWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.zip.InflaterInputStream;
@Description(
name = "decompress",
value = "_FUNC_(BINARY) -> BINARY, uses Java Inflater",
extended = "Example:\n > SELECT decode(_FUNC_(unbase64('eJzLSM3JyQcABiwCFQ==')), 'ASCII') -> 'hello'"
)
public class InflateUDF extends GenericUDF {
static final Logger logger = LoggerFactory.getLogger(InflateUDF.class.getName());
private transient BinaryObjectInspector binOI;
private final transient ByteArrayOutputStream bos = new ByteArrayOutputStream(10 * 1024);
private final transient BytesWritable result = new BytesWritable();
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
if (arguments.length != 1) {
throw new UDFArgumentLengthException(String.format("The function accepts exactly one binary type argument, but got %s.", arguments.length));
} else if (!(arguments[0] instanceof PrimitiveObjectInspector) || ((PrimitiveObjectInspector) arguments[0]).getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.BINARY) {
throw new UDFArgumentTypeException(0, String.format("The function accepts exactly one binary type argument, but got %s.", arguments[0].getTypeName()));
}
binOI = ((BinaryObjectInspector) arguments[0]);
return PrimitiveObjectInspectorFactory.writableBinaryObjectInspector;
}
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
if (arguments == null || arguments.length < 1) {
throw new HiveException("arguments is empty");
}
final Object arg0 = arguments[0].get();
if (arg0 == null) {
logger.warn("argument is null");
return null;
}
bos.reset();
final byte[] bin = binOI.getPrimitiveJavaObject(arg0);
// try-with-resources closes the stream and releases the native Inflater even if reading fails
try (InflaterInputStream unzip = new InflaterInputStream(new ByteArrayInputStream(bin))) {
byte[] buffer = new byte[512];
int n;
while ((n = unzip.read(buffer)) >= 0) {
bos.write(buffer, 0, n);
}
} catch (IOException e) {
logger.error("failed to decompress data (size={}). error: {}", bin.length, Throwables.getStackTraceAsString(e));
return null;
}
result.setSize(bos.size());
result.set(bos.toByteArray(), 0, bos.size());
return result;
}
@Override
public String getDisplayString(String[] children) {
return "Usage:" + this.getClass().getName() + "(" + children[0] + ")";
}
@Override
public void close() throws IOException {
super.close();
bos.close();
}
}
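The decompress UDF can be exercised the same way; assuming both functions from this gist are registered, the round trip below should return the original string.
CREATE TEMPORARY FUNCTION decompress AS 'com.example.udf.InflateUDF';
SELECT decode(decompress(unbase64('eJzLSM3JyQcABiwCFQ==')), 'ASCII');  -- 'hello'
SELECT decode(decompress(compress(encode('hello', 'ASCII'))), 'ASCII');  -- round trip, also 'hello'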
package com.example.udaf;
import com.google.common.collect.Sets;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
@Description(
name = "map_sum",
value = "_FUNC_(x) - sum the values of multiple maps grouping by map keys, and return the summed map"
)
public class MapSumUDAF extends AbstractGenericUDAFResolver {
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
if (parameters.length != 1) {
throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
}
if (!parameters[0].getCategory().equals(Category.MAP)) {
throw new UDFArgumentTypeException(0, "One argument is expected, taking an MAP as an argument");
}
StandardMapObjectInspector inputMap = (StandardMapObjectInspector) TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
ObjectInspector keyOI = inputMap.getMapKeyObjectInspector();
ObjectInspector valueOI = inputMap.getMapValueObjectInspector();
if (keyOI.getCategory() != Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0, "Map key must be PRIMITIVE, but " + keyOI.getCategory().name() + " was passed.");
}
if (valueOI.getCategory() != Category.PRIMITIVE) {
throw new UDFArgumentTypeException(0, "Map value must be PRIMITIVE, but " + valueOI.getCategory().name() + " was passed.");
}
return new MapSumEvaluator();
}
public static class MapSumEvaluator extends GenericUDAFEvaluator {
public static HashSet<PrimitiveObjectInspector.PrimitiveCategory> intTypes = Sets.newHashSet(
PrimitiveObjectInspector.PrimitiveCategory.LONG,
PrimitiveObjectInspector.PrimitiveCategory.INT,
PrimitiveObjectInspector.PrimitiveCategory.SHORT,
PrimitiveObjectInspector.PrimitiveCategory.BYTE
);
private StandardMapObjectInspector inputMapOI;
private PrimitiveObjectInspector inputKeyOI;
private PrimitiveObjectInspector inputValOI;
private StandardMapObjectInspector midMapOI;
private PrimitiveObjectInspector outputValOI;
static class MapAggBuffer implements AggregationBuffer {
HashMap<Object, Object> collectMap = new HashMap<Object, Object>();
}
@Override
public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
super.init(m, parameters);
// PARTIAL1/COMPLETE read the original SQL input; summed values are widened to long or double
StandardMapObjectInspector outputMapOI;
if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
inputMapOI = (StandardMapObjectInspector) parameters[0];
inputKeyOI = (PrimitiveObjectInspector) inputMapOI.getMapKeyObjectInspector();
inputValOI = (PrimitiveObjectInspector) inputMapOI.getMapValueObjectInspector();
if (intTypes.contains(inputValOI.getPrimitiveCategory())) {
outputValOI = PrimitiveObjectInspectorFactory.javaLongObjectInspector;
} else {
outputValOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
}
outputMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(inputKeyOI, outputValOI);
} else {
// The other modes read partial results, which already have the final output type, so no conversion is needed
midMapOI = ((StandardMapObjectInspector) parameters[0]);
inputKeyOI = ((PrimitiveObjectInspector) midMapOI.getMapKeyObjectInspector());
inputValOI = ((PrimitiveObjectInspector) midMapOI.getMapValueObjectInspector());
outputMapOI = midMapOI;
outputValOI = ((PrimitiveObjectInspector) outputMapOI.getMapValueObjectInspector());
}
return outputMapOI;
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new MapAggBuffer();
}
@Override
public void reset(AggregationBuffer buff) throws HiveException {
MapAggBuffer aggBuffer = (MapAggBuffer) buff;
aggBuffer.collectMap.clear();
}
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
Object mpObj = parameters[0];
if (mpObj != null) {
MapAggBuffer myAgg = (MapAggBuffer) agg;
// getMap() may return a lazy Map implementation, so avoid casting it to HashMap
Map<?, ?> mp = inputMapOI.getMap(mpObj);
for (Map.Entry<?, ?> entry : mp.entrySet()) {
putIntoMap(entry.getKey(), entry.getValue(), myAgg);
}
}
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
MapAggBuffer myAgg = (MapAggBuffer) agg;
return new HashMap<>(myAgg.collectMap);
}
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}
MapAggBuffer myAgg = (MapAggBuffer) agg;
Map<?, ?> partialMap = midMapOI.getMap(partial);
for (Map.Entry<?, ?> entry : partialMap.entrySet()) {
putIntoMap(entry.getKey(), entry.getValue(), myAgg);
}
}
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
MapAggBuffer myAgg = (MapAggBuffer) agg;
return new HashMap<>(myAgg.collectMap);
}
private void putIntoMap(Object key, Object val, MapAggBuffer myAgg) {
Object keyCopy = ObjectInspectorUtils.copyToStandardObject(key, this.inputKeyOI);
Object valCopy = ObjectInspectorUtils.copyToStandardObject(val, this.inputValOI);
if (myAgg.collectMap.containsKey(keyCopy)) {
if (outputValOI.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.LONG) {
myAgg.collectMap.put(keyCopy, ((Number) myAgg.collectMap.get(keyCopy)).longValue() + ((Number) valCopy).longValue());
} else {
myAgg.collectMap.put(keyCopy, ((Number) myAgg.collectMap.get(keyCopy)).doubleValue() + ((Number) valCopy).doubleValue());
}
} else {
if (outputValOI.getPrimitiveCategory() == PrimitiveObjectInspector.PrimitiveCategory.LONG) {
myAgg.collectMap.put(keyCopy, ((Number) valCopy).longValue());
} else {
myAgg.collectMap.put(keyCopy, ((Number) valCopy).doubleValue());
}
}
}
}
}
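A possible map_sum usage sketch; the table page_views and the columns country and metrics are invented for illustration.
CREATE TEMPORARY FUNCTION map_sum AS 'com.example.udaf.MapSumUDAF';
-- e.g. {'clicks': 3, 'views': 10} + {'clicks': 1, 'views': 5} -> {'clicks': 4, 'views': 15}
SELECT country, map_sum(metrics) AS total_metrics
FROM page_views
GROUP BY country;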
package com.example.udaf;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import java.util.HashMap;
import java.util.Map;
/**
 * A collect_map style aggregate that supports arbitrarily deep nesting. The last argument is the value;
 * all preceding arguments are nested keys. Values are summed during aggregation; to count rows per key,
 * pass 1 as the value.
 * Examples:
 * - nested_map(sex, cnt) -> {'male': 10000, 'female': 5000}
 * - nested_map(sex, country, cnt) -> {'male': {'US': 200, 'JP': 100, ...}, 'female': {'US': 100, 'JP': 50, ...}}
 */
@Description(name = "nested_map", value = "_FUNC_(k1, k2, k3..., v) nested map")
public class NestedMapUDAF extends AbstractGenericUDAFResolver {
@Override
public GenericUDAFEvaluator getEvaluator(TypeInfo[] paramsTypes) throws SemanticException {
for (int i = 0; i < paramsTypes.length; i++) {
if (paramsTypes[i].getCategory() != ObjectInspector.Category.PRIMITIVE) {
throw new UDFArgumentTypeException(i, "key or value must be primitive type");
}
}
return new NestedMapUDAFEvaluator();
}
public static class NestedMapUDAFEvaluator extends GenericUDAFEvaluator {
private int nestedMapDepth;
private PrimitiveObjectInspector[] inputsInspectors;
private PrimitiveObjectInspector lastValueInspector;
private StandardMapObjectInspector mapObjectInspector;
static class NestedMapCollector extends AbstractAggregationBuffer {
HashMap<Object, Object> collectMap = new HashMap<>();
}
@Override
public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
super.init(mode, parameters);
if (Mode.PARTIAL1 == mode || Mode.COMPLETE == mode) {
if (parameters.length < 2) {
throw new HiveException("input parameter length should not less 2");
}
switch (((PrimitiveObjectInspector) parameters[parameters.length - 1]).getPrimitiveCategory()) {
case LONG:
case INT:
break;
default:
throw new HiveException("function last value should be int or long");
}
inputsInspectors = new PrimitiveObjectInspector[parameters.length - 1];
for (int i = 0; i < parameters.length - 1; i++) {
inputsInspectors[i] = (PrimitiveObjectInspector) parameters[i];
}
// last parameter is last map value
nestedMapDepth = parameters.length - 1;
lastValueInspector = (PrimitiveObjectInspector) parameters[parameters.length - 1];
PrimitiveObjectInspector keyObjectInspector = (PrimitiveObjectInspector) parameters[parameters.length - 2];
StandardMapObjectInspector nestedMapInspector = ObjectInspectorFactory.getStandardMapObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(keyObjectInspector),
PrimitiveObjectInspectorFactory.getPrimitiveJavaObjectInspector(TypeInfoFactory.longTypeInfo));
for (int i = parameters.length - 3; i >= 0; i--) {
keyObjectInspector = (PrimitiveObjectInspector) parameters[i];
nestedMapInspector = ObjectInspectorFactory.getStandardMapObjectInspector(ObjectInspectorUtils.getStandardObjectInspector(keyObjectInspector), nestedMapInspector);
}
return nestedMapInspector;
} else {
nestedMapDepth = 1;
mapObjectInspector = (StandardMapObjectInspector) parameters[0];
while (mapObjectInspector.getMapValueObjectInspector().getCategory() == ObjectInspector.Category.MAP) {
nestedMapDepth++;
mapObjectInspector = (StandardMapObjectInspector) mapObjectInspector.getMapValueObjectInspector();
}
lastValueInspector = (PrimitiveObjectInspector) mapObjectInspector.getMapValueObjectInspector();
return parameters[0];
}
}
@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
return new NestedMapCollector();
}
// Process raw input rows (map side)
@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
if (null == parameters[parameters.length - 1]) {
return;
}
NestedMapCollector aggMap = (NestedMapCollector) agg;
Object lastKey = inputsInspectors[nestedMapDepth - 1].getPrimitiveJavaObject(parameters[parameters.length - 2]);
if (null == lastKey) {
lastKey = "NULL";
}
long value = ((Number) lastValueInspector.getPrimitiveJavaObject(parameters[parameters.length - 1])).longValue();
Map<Object, Object> tmpMap = aggMap.collectMap;
for (int i = 0; i < nestedMapDepth - 1; i++) {
Object key = inputsInspectors[i].getPrimitiveJavaObject(parameters[i]);
if (null == key) {
key = "NULL";
}
if (tmpMap.containsKey(key)) {
tmpMap = (Map<Object, Object>) tmpMap.get(key);
} else {
Map<Object, Object> newMap = new HashMap<Object, Object>();
tmpMap.put(key, newMap);
tmpMap = newMap;
}
}
long defaultValue = (long) tmpMap.getOrDefault(lastKey, 0L);
tmpMap.put(lastKey, defaultValue + value);
}
// Merge partial results produced by terminatePartial
@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}
NestedMapCollector mapAgg = (NestedMapCollector) agg;
Map<Object, Object> otherMap = (Map<Object, Object>) mapObjectInspector.getMap(partial);
mapMerge(mapAgg.collectMap, otherMap, nestedMapDepth);
}
private static void mapMerge(Map<Object, Object> collectMap, Map<Object, Object> updateMap, int depth) {
if (1 == depth) {
for (Map.Entry<Object, Object> entry : updateMap.entrySet()) {
Object key = entry.getKey();
long updateValue = (long) entry.getValue();
long defaultValue = (long) collectMap.getOrDefault(key, 0L);
collectMap.put(key, updateValue + defaultValue);
}
} else {
for (Map.Entry<Object, Object> entry : updateMap.entrySet()) {
Object key = entry.getKey();
Map<Object, Object> defaultMap;
if (collectMap.containsKey(key)) {
defaultMap = (Map<Object, Object>) collectMap.get(key);
} else {
defaultMap = new HashMap<>();
collectMap.put(key, defaultMap);
}
mapMerge(defaultMap, ((Map<Object, Object>) entry.getValue()), depth - 1);
}
}
}
@Override
public void reset(AggregationBuffer buff) throws HiveException {
((NestedMapCollector) buff).collectMap.clear();
}
// Emit the final aggregated result
@Override
public Object terminate(AggregationBuffer agg) throws HiveException {
NestedMapCollector mapAgg = (NestedMapCollector) agg;
return mapAgg.collectMap;
}
@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
return terminate(agg);
}
}
}
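Finally, a nested_map usage sketch following the examples in the class comment; the table user_stats and its columns are assumptions.
CREATE TEMPORARY FUNCTION nested_map AS 'com.example.udaf.NestedMapUDAF';
-- one key level: {'male': 10000, 'female': 5000}
SELECT nested_map(sex, cnt) FROM user_stats;
-- two key levels: {'male': {'US': 200, 'JP': 100}, 'female': {'US': 100, 'JP': 50}}
SELECT nested_map(sex, country, cnt) FROM user_stats;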