package cn.com.duiba.udf;

import com.google.common.collect.Lists;
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.hive.ql.exec.UDF;

import java.util.List;

/**
 * Hive UDF that computes the element-wise average of several vectors encoded in one string.
 *
 * <p>Example: {@code evaluate("1,2;3,4", ";", ",")} returns {@code "2.0,3.0"}.
 */
public class VectorAvgUDF extends UDF {

    /**
     * Averages the vectors encoded in {@code str}, component by component.
     *
     * <p>Vectors may have different lengths. The result has as many components as the longest
     * vector; a vector that has no component at a given position contributes {@code 0} to that
     * position's sum, and the sum is always divided by the total number of vectors.
     *
     * <p>Both {@code delimiter} and {@code separator} are passed to
     * {@code StringUtils.split(String, String)}, which treats every character of the argument as
     * an individual separator character and drops empty tokens.
     *
     * @param str       the encoded vectors, e.g. {@code "1,2;3,4"}
     * @param delimiter characters separating one vector from the next, e.g. {@code ";"}
     * @param separator characters separating the components of a vector, e.g. {@code ","};
     *                  also used verbatim to join the components of the result
     * @return the averaged vector joined with {@code separator}; {@code null} when {@code str} is
     *         blank, when {@code delimiter} or {@code separator} is null or empty, or when
     *         {@code str} contains no vector at all
     * @throws NumberFormatException if a component cannot be parsed as a double
     */
    public String evaluate(String str, String delimiter, String separator) {
        // Only null/empty delimiters are rejected: whitespace such as a space or a tab is a
        // legitimate delimiter, so isBlank must not be used for them.
        if (StringUtils.isBlank(str) || StringUtils.isEmpty(delimiter) || StringUtils.isEmpty(separator)) {
            return null;
        }
        String[] encodedVectors = StringUtils.split(str, delimiter);
        if (ArrayUtils.isEmpty(encodedVectors)) {
            return null;
        }

        // Split every encoded vector into its components and track the longest one.
        int dimension = 0;
        List<String[]> vectors = Lists.newArrayListWithCapacity(encodedVectors.length);
        for (String encoded : encodedVectors) {
            String[] components = StringUtils.split(encoded, separator);
            dimension = getMaxLength(dimension, components);
            vectors.add(components);
        }

        // Average each position; vectors is never empty here because encodedVectors is not.
        List<Double> averages = Lists.newArrayListWithCapacity(dimension);
        for (int i = 0; i < dimension; i++) {
            // Primitive accumulator avoids boxing on every addition.
            double sum = 0D;
            for (String[] vector : vectors) {
                if (vector != null && vector.length > i) {
                    sum += Double.parseDouble(vector[i]);
                }
            }
            averages.add(sum / vectors.size());
        }
        return StringUtils.join(averages, separator);
    }

    /**
     * Returns the larger of {@code maxLength} and the length of {@code split}.
     *
     * @param maxLength the longest length seen so far
     * @param split     the components of one vector, may be {@code null}
     * @return {@code maxLength} when {@code split} is {@code null}, otherwise the larger value
     */
    private int getMaxLength(int maxLength, String[] split) {
        if (split == null) {
            return maxLength;
        }
        return Math.max(maxLength, split.length);
    }
}
