NumPyJuniorCoding

Как работает np.where() и каковы его типичные сценарии использования?

np.where(condition, x, y) возвращает элементы из x там, где condition True, из y — где False. С одним аргументом возвращает индексы ненулевых элементов. Используется для векторизованных условий без Python-циклов.

np.where() в NumPy

np.where — векторизованная условная операция. Существует в двух формах: с тремя аргументами (условие + два значения) и с одним аргументом (поиск индексов).

Базовый синтаксис

import numpy as np

# Форма 1: np.where(condition, x, y)
# Возвращает массив той же формы: x[i] если condition[i] True, иначе y[i]
a = np.array([1, -2, 3, -4, 5])
result = np.where(a > 0, a, 0)   # ReLU: отрицательные -> 0
print(result)  # [1 0 3 0 5]

# Форма 2: np.where(condition)
# Возвращает TUPLE индексов (аналог np.nonzero)
indices = np.where(a > 0)        # (array([0, 2, 4]),)
print(a[indices])                 # [1 3 5]

# Для 2D-массива возвращает кортеж из N массивов (по одному на ось)
M = np.array([[1, -2], [-3, 4]])
rows, cols = np.where(M > 0)
print(rows, cols)  # [0 1] [0 1] — позиции положительных элементов

Практические сценарии

# 1. Замена аномальных значений (clipping с условием)
temperatures = np.array([22.5, 999.0, 18.3, -999.0, 25.1])  # 999 = sensor error
cleaned = np.where(
    np.abs(temperatures) > 100,
    np.nan,
    temperatures
)
print(cleaned)  # [ 22.5  nan  18.3  nan  25.1]

# 2. Бинаризация (threshold)
scores = np.array([0.2, 0.7, 0.45, 0.9, 0.3])
labels = np.where(scores >= 0.5, 1, 0)
print(labels)  # [0 1 0 1 0]

# 3. Множественные условия через np.select (обобщение np.where)
ages = np.array([5, 15, 25, 65, 80])
conditions = [
    ages < 13,
    ages < 18,
    ages < 65,
]
choices = ['child', 'teen', 'adult']
groups = np.select(conditions, choices, default='senior')
print(groups)  # ['child' 'teen' 'adult' 'senior' 'senior']

# 4. Векторизованный lookup (замена dict.get)
category_ids = np.array([0, 2, 1, 3, 1])
category_names = np.array(['cat', 'dog', 'bird', 'fish'])
names = category_names[category_ids]  # fancy indexing
print(names)  # ['cat' 'bird' 'dog' 'fish' 'bird']

# 5. Условное присваивание in-place
prices = np.array([100.0, 50.0, 200.0, 30.0])
# Скидка 10% для товаров дороже 80
prices[prices > 80] *= 0.9   # через boolean indexing + in-place
# или через where:
new_prices = np.where(prices > 80, prices * 0.9, prices)

# 6. Нахождение ближайшего значения
bins = np.array([0, 10, 20, 30, 40, 50])
val = 23
idx = np.where(bins <= val)[0][-1]  # последний bin <= val
print(f"Value {val} falls in bin starting at {bins[idx]}")  # bin 20

# 7. Маскирование в pandas-стиле без pandas
X = np.random.randn(100, 5).astype(np.float32)
mask = np.where(np.isnan(X))    # найти NaN
X[mask] = 0.0                    # заменить нулями

Производительность vs Python-цикл

import time

data = np.random.randn(1_000_000).astype(np.float32)

# Python цикл: медленно
start = time.perf_counter()
result_loop = [x if x > 0 else 0.0 for x in data]
print(f"Loop: {time.perf_counter() - start:.3f}s")

# np.where: быстро
start = time.perf_counter()
result_np = np.where(data > 0, data, 0.0)
print(f"np.where: {time.perf_counter() - start:.4f}s")
# Loop: ~0.3s, np.where: ~0.003s — в 100 раз быстрее

Подводные камни

  • Оба аргумента x и y вычисляются всегда, даже если они не нужны — np.where(mask, 1/a, 0) при a=0 всё равно вызовет деление на ноль для всех элементов.
  • С одним аргументом np.where(condition) возвращает кортеж массивов, а не один массив — распаковка idx = np.where(a > 0)[0] обязательна для 1D.
  • Тип результата определяется по x и y через broadcasting: np.where(cond, int_arr, float_arr) даёт float64 — может неожиданно поменять dtype.
  • На очень больших массивах np.where создаёт временный булевый массив той же формы — потребление памяти увеличивается на N байт (N = len(array)).
  • np.where не ленивый — для сложных условий с несколькими порогами np.select чище и эффективнее цепочки вложенных np.where.
  • При работе с NaN: np.where(np.isnan(x), 0, x) работает корректно, но np.where(x > 0, x, np.nan) меняет dtype на float64 даже для float32-массива.

Common mistakes

  • Объяснять where conditional selection только синтаксисом без shape, dtype, состояния или режима выполнения.
  • Игнорировать leakage, воспроизводимость, пустые входы и скрытые копии данных.
  • Не проверять production-симптомы: latency, память, ретраи, дрейф качества и несовпадение версий.

What the interviewer is testing

  • Может ли связать where conditional selection с реальным контрактом входов и выходов.
  • Упоминает ли тесты, метрики, reproducibility и диагностику ошибок.
  • Видит ли различие между demo-кодом в ноутбуке и production-пайплайном.

Sources

Related topics