rsr.rsr.classify_samples_with_indices

rsr.rsr.classify_samples_with_indices(samples, upper_refs, lower_refs, *, return_masks=False)

Classify each sample as upper, lower, or unknown using subset checks against the reference tensors, and return the sample indices for each class.

Parameters:
  • samples (Tensor) – (n_sample, n_var, n_state) binary tensor

  • upper_refs (List[Tensor]) – list of reference tensors, each of shape (n_var, n_state) or (n_var+1, n_state)

  • lower_refs (List[Tensor]) – list of reference tensors, each of shape (n_var, n_state) or (n_var+1, n_state)

  • return_masks (bool) – if True, also return boolean masks per class

Return type:

Dict[str, Any]

Returns:

A dictionary of the form:

{
  'upper': int,
  'lower': int,
  'unknown': int,
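  'idx_upper': LongTensor[ns],    # indices of samples classified as upper (length ns)
  'idx_lower': LongTensor[nf],    # indices of samples classified as lower (length nf)
  'idx_unknown': LongTensor[nu],  # indices of samples classified as unknown (length nu)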
  # optionally (if return_masks=True):
  'mask_upper': BoolTensor[n_sample],
  'mask_lower': BoolTensor[n_sample],
  'mask_unknown': BoolTensor[n_sample],
}
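
The following is a minimal sketch of how a caller might sanity-check a returned dictionary of the form above. It is not part of rsr: `check_result` is a hypothetical helper, and it rests on two assumptions that the documentation does not state explicitly, namely that each integer count equals the length of the matching index tensor, and that the three classes together cover every sample exactly once. Plain Python lists stand in for the LongTensors so the example runs without any dependency.

```python
def check_result(result, n_sample):
    """Sanity-check a result dict of the documented form (hypothetical helper)."""
    # int() accepts Python ints as well as 0-d tensor elements,
    # so the same code works on lists and on 1-D index tensors.
    groups = {
        name: [int(i) for i in result[f"idx_{name}"]]
        for name in ("upper", "lower", "unknown")
    }

    # Assumption: each count equals the length of its index tensor.
    for name, idx in groups.items():
        if int(result[name]) != len(idx):
            return False

    # Assumption: the three classes partition range(n_sample).
    all_idx = sorted(i for idx in groups.values() for i in idx)
    return all_idx == list(range(n_sample))


# Stand-in for a real return value; lists replace the LongTensors.
example = {
    "upper": 2, "lower": 1, "unknown": 1,
    "idx_upper": [0, 3], "idx_lower": [1], "idx_unknown": [2],
}
ok = check_result(example, n_sample=4)
```

If either assumption does not hold for the actual implementation, the corresponding check should be dropped or adapted.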