检查和索引numpy数组中的非唯一/重复值
问题内容:
我有一个traced_descIDs
包含对象ID的数组,我想确定哪些项目在此数组中不是唯一的。然后,对于每个唯一的重复(仔细的)ID,我需要确定traced_descIDs
与之关联的索引。
例如,如果我们在此处采用traced_descID,则希望发生以下过程:
traced_descIDs = [1, 345, 23, 345, 90, 1]
dupIds = [1, 345]
dupInds = [[0,5],[1,3]]
我目前正在通过以下方式找出哪些对象具有多个条目:
mentions = np.array([len(np.argwhere( traced_descIDs == i)) for i in traced_descIDs])
dupMask = (mentions > 1)
但是,这花费了len( traced_descIDs )
大约15万的时间。有没有更快的方法来达到相同的结果?
任何帮助,不胜感激。干杯。
问题答案:
虽然字典是O(n),但Python对象的开销有时使使用numpy函数更方便,这些函数使用排序功能并且是O(n * log n)。在您的情况下,起点将是:
a = [1, 345, 23, 345, 90, 1]
unq, unq_idx, unq_cnt = np.unique(a, return_inverse=True, return_counts=True)
如果您使用的numpy版本早于1.9,则最后一行必须是:
unq, unq_idx = np.unique(a, return_inverse=True)
unq_cnt = np.bincount(unq_idx)
我们创建的三个数组的内容是:
>>> unq
array([ 1, 23, 90, 345])
>>> unq_idx
array([0, 3, 1, 3, 2, 0])
>>> unq_cnt
array([2, 1, 1, 2])
获取重复项:
cnt_mask = unq_cnt > 1
dup_ids = unq[cnt_mask]
>>> dup_ids
array([ 1, 345])
获取索引要稍微复杂一点,但是非常简单:
cnt_idx, = np.nonzero(cnt_mask)
idx_mask = np.in1d(unq_idx, cnt_idx)
idx_idx, = np.nonzero(idx_mask)
srt_idx = np.argsort(unq_idx[idx_mask])
dup_idx = np.split(idx_idx[srt_idx], np.cumsum(unq_cnt[cnt_mask])[:-1])
>>> dup_idx
[array([0, 5]), array([1, 3])]