# Source code for opensearchpy.helpers.aggs

# SPDX-License-Identifier: Apache-2.0
#
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
#
# Modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
#
#  Licensed to Elasticsearch B.V. under one or more contributor
#  license agreements. See the NOTICE file distributed with
#  this work for additional information regarding copyright
#  ownership. Elasticsearch B.V. licenses this file to you under
#  the Apache License, Version 2.0 (the "License"); you may
#  not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
# 	http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an
#  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
#  KIND, either express or implied.  See the License for the
#  specific language governing permissions and limitations
#  under the License.


import collections.abc as collections_abc
from typing import Any, Optional

from .response.aggs import AggResponse, BucketData, FieldBucketData, TopHitsData
from .utils import DslBase


def A(  # pylint: disable=invalid-name
    name_or_agg: Any, filter: Any = None, **params: Any
) -> Any:
    """
    Build an aggregation object from a type name, a raw mapping, or an ``Agg``.

    Accepted forms::

        A("terms", field="tags")
        A({"terms": {"field": "tags"}, "aggs": {...}, "meta": {...}})
        A(Terms(field="tags"))  # an Agg instance is returned unchanged

    :arg name_or_agg: aggregation type name, a mapping holding exactly one
        aggregation (plus optional ``aggs`` / ``meta`` keys), or an ``Agg``.
    :arg filter: positional shortcut accepted only for the ``"filter"``
        aggregation; it is stored as that aggregation's ``filter`` parameter.
    :arg params: aggregation parameters; not allowed together with a mapping
        or an ``Agg`` instance.
    :returns: an ``Agg`` instance.
    :raises ValueError: if ``filter`` is given for any aggregation other than
        ``"filter"``, if ``params`` are combined with a mapping or an ``Agg``,
        or if the mapping does not contain exactly one aggregation besides
        ``aggs`` / ``meta``.
    """
    if filter is not None:
        if name_or_agg != "filter":
            # The 1-tuple keeps %-formatting safe even if name_or_agg is a tuple.
            raise ValueError(
                "Aggregation %r doesn't accept positional argument 'filter'."
                % (name_or_agg,)
            )
        params["filter"] = filter

    # {"terms": {"field": "tags"}, "aggs": {...}}
    if isinstance(name_or_agg, collections_abc.Mapping):
        if params:
            raise ValueError("A() cannot accept parameters when passing in a dict.")
        # Shallow-copy into a plain dict. This works for any Mapping (the ABC
        # does not provide .copy()) and leaves the caller's object untouched.
        agg = dict(name_or_agg)
        # nested sub-aggregations and metadata are carried separately
        aggs = agg.pop("aggs", None)
        meta = agg.pop("meta", None)
        # what remains should be exactly {"terms": {"field": "tags"}}
        if len(agg) != 1:
            raise ValueError(
                'A() can only accept dict with an aggregation ({"terms": {...}}). '
                "Instead it got (%r)" % (name_or_agg,)
            )
        agg_type, params = agg.popitem()
        # Build new dicts rather than mutating the caller's nested body.
        # Empty (falsy) aggs/meta values are dropped.
        if aggs:
            params = {**params, "aggs": aggs}
        if meta:
            params = {**params, "meta": meta}
        return Agg.get_dsl_class(agg_type)(_expand__to_dot=False, **params)

    # Terms(...) just return the nested agg
    if isinstance(name_or_agg, Agg):
        if params:
            raise ValueError(
                "A() cannot accept parameters when passing in an Agg object."
            )
        return name_or_agg

    # "terms", field="tags"
    return Agg.get_dsl_class(name_or_agg)(**params)


class Agg(DslBase):
    """
    Common base class for aggregation types.

    A bare ``Agg`` holds no named sub-aggregations, so membership tests
    always answer ``False``.
    """

    _type_name: str = "agg"
    # Shortcut factory handed to the DSL machinery for building aggregations
    # (presumably used by DslBase for coercion -- confirm in .utils).
    _type_shortcut = staticmethod(A)
    # Aggregation type key; concrete subclasses override it.
    name: Optional[str] = None

    def __contains__(self, key: Any) -> bool:
        # No sub-aggregations live on a plain Agg.
        return False

    def to_dict(self) -> Any:
        """
        Serialize via the parent class, then lift ``meta`` out of the body.

        The parent's output is indexed by ``self.name``. If that body carries a
        ``meta`` entry, it is moved to the top level beside the body.
        """
        serialized = super().to_dict()
        body = serialized[self.name]
        if "meta" in body:
            serialized["meta"] = body.pop("meta")
        return serialized

    def result(self, search: Any, data: Any) -> Any:
        """Wrap this aggregation's raw response ``data`` in an ``AggResponse``."""
        return AggResponse(self, search, data)
class AggBase:
    """
    Mixin for aggregations that can own named sub-aggregations under ``aggs``.

    Provides ``in`` / ``[]`` / iteration over the nested aggregations and the
    chaining helpers ``metric``, ``bucket`` and ``pipeline``.
    """

    _param_defs = {
        "aggs": {"type": "agg", "hash": True},
    }

    def __contains__(self: Any, key: Any) -> bool:
        return key in self._params.get("aggs", {})

    def __getitem__(self: Any, agg_name: Any) -> Any:
        children = self._params.setdefault("aggs", {})
        found = children[agg_name]  # a missing name raises KeyError to the caller
        if not isinstance(found, Bucket):
            return found
        # Hand out a shallow copy of a bucket so shared state is never mutated,
        # and store that copy back so edits made through it affect us.
        fresh = A(found.name, **found._params)
        children[agg_name] = fresh
        return fresh

    def __setitem__(self: Any, agg_name: str, agg: Any) -> None:
        self.aggs[agg_name] = A(agg)

    def __iter__(self: Any) -> Any:
        return iter(self.aggs)

    def _agg(
        self: Any, bucket: Any, name: Any, agg_type: Any, *args: Any, **params: Any
    ) -> Any:
        """Create and attach a sub-aggregation, returning the next chaining target."""
        created = A(agg_type, *args, **params)
        self[name] = created
        # New buckets are returned so chaining continues inside them; anything
        # else returns ``self._base`` so chaining continues at this level.
        return created if bucket else self._base

    def metric(self: Any, name: Any, agg_type: Any, *args: Any, **params: Any) -> Any:
        """Attach a metric sub-aggregation; returns ``self._base`` for chaining."""
        return self._agg(False, name, agg_type, *args, **params)

    def bucket(self: Any, name: Any, agg_type: Any, *args: Any, **params: Any) -> Any:
        """Attach a bucket sub-aggregation and return the new bucket."""
        return self._agg(True, name, agg_type, *args, **params)

    def pipeline(
        self: Any, name: Any, agg_type: Any, *args: Any, **params: Any
    ) -> Any:
        """Attach a pipeline sub-aggregation; returns ``self._base`` for chaining."""
        return self._agg(False, name, agg_type, *args, **params)

    def result(self: Any, search: Any, data: Any) -> Any:
        """Wrap the raw response ``data`` in ``BucketData``."""
        return BucketData(self, search, data)


class Bucket(AggBase, Agg):
    """Aggregation that can hold nested sub-aggregations."""

    def __init__(self, **params: Any) -> None:
        super().__init__(**params)
        # A bucket is its own chaining base (see AggBase._agg).
        self._base = self

    def to_dict(self) -> Any:
        """Serialize, lifting nested ``aggs`` out of the body to the top level."""
        # Lookup starts after AggBase in the MRO, i.e. at Agg.to_dict.
        serialized = super(AggBase, self).to_dict()
        body = serialized[self.name]
        if "aggs" in body:
            serialized["aggs"] = body.pop("aggs")
        return serialized


class Filter(Bucket):
    """Single-filter bucket; ``filter`` may be passed positionally."""

    name: Optional[str] = "filter"
    _param_defs = {
        "filter": {"type": "query"},
        "aggs": {"type": "agg", "hash": True},
    }

    def __init__(self, filter: Any = None, **params: Any) -> None:
        if filter is not None:
            params = {**params, "filter": filter}
        super().__init__(**params)

    def to_dict(self) -> Any:
        """Serialize, merging the serialized ``filter`` query into the body."""
        serialized = super().to_dict()
        body = serialized[self.name]
        body.update(body.pop("filter", {}))
        return serialized


class Pipeline(Agg):
    """Common base class for pipeline aggregations."""


# bucket aggregations
class Filters(Bucket):
    name: str = "filters"
    _param_defs = {
        "filters": {"type": "query", "hash": True},
        "aggs": {"type": "agg", "hash": True},
    }


class Children(Bucket):
    name = "children"


class Parent(Bucket):
    name = "parent"


class DateHistogram(Bucket):
    name = "date_histogram"

    def result(self, search: Any, data: Any) -> Any:
        return FieldBucketData(self, search, data)


class AutoDateHistogram(DateHistogram):
    name = "auto_date_histogram"


class DateRange(Bucket):
    name = "date_range"


class GeoDistance(Bucket):
    name = "geo_distance"


class GeohashGrid(Bucket):
    name = "geohash_grid"


class GeotileGrid(Bucket):
    name = "geotile_grid"


# NOTE(review): geo_centroid is typically a metric aggregation, yet it is
# declared as a Bucket here -- confirm this is intentional before changing it.
class GeoCentroid(Bucket):
    name = "geo_centroid"


class Global(Bucket):
    name = "global"


class Histogram(Bucket):
    name = "histogram"

    def result(self, search: Any, data: Any) -> Any:
        return FieldBucketData(self, search, data)


class IPRange(Bucket):
    name = "ip_range"


class Missing(Bucket):
    name = "missing"


class Nested(Bucket):
    name = "nested"


class Range(Bucket):
    name = "range"


class RareTerms(Bucket):
    name = "rare_terms"

    def result(self, search: Any, data: Any) -> Any:
        return FieldBucketData(self, search, data)


class ReverseNested(Bucket):
    name = "reverse_nested"


class SignificantTerms(Bucket):
    name = "significant_terms"


class SignificantText(Bucket):
    name = "significant_text"


class Terms(Bucket):
    name = "terms"

    def result(self, search: Any, data: Any) -> Any:
        return FieldBucketData(self, search, data)


class Sampler(Bucket):
    name = "sampler"


class DiversifiedSampler(Bucket):
    name = "diversified_sampler"


class Composite(Bucket):
    name = "composite"
    _param_defs = {
        "sources": {"type": "agg", "hash": True, "multi": True},
        "aggs": {"type": "agg", "hash": True},
    }


class VariableWidthHistogram(Bucket):
    name = "variable_width_histogram"

    def result(self, search: Any, data: Any) -> Any:
        return FieldBucketData(self, search, data)


class MultiTerms(Bucket):
    name = "multi_terms"


# metric aggregations
class TopHits(Agg):
    name = "top_hits"

    def result(self, search: Any, data: Any) -> Any:
        return TopHitsData(self, search, data)


class Avg(Agg):
    name = "avg"


class WeightedAvg(Agg):
    name = "weighted_avg"


class Cardinality(Agg):
    name = "cardinality"


class ExtendedStats(Agg):
    name = "extended_stats"


class Boxplot(Agg):
    name = "boxplot"


class GeoBounds(Agg):
    name = "geo_bounds"


class Max(Agg):
    name = "max"


class MedianAbsoluteDeviation(Agg):
    name = "median_absolute_deviation"


class Min(Agg):
    name = "min"


class Percentiles(Agg):
    name = "percentiles"


class PercentileRanks(Agg):
    name = "percentile_ranks"


class ScriptedMetric(Agg):
    name = "scripted_metric"


class Stats(Agg):
    name = "stats"


class Sum(Agg):
    name = "sum"


class TTest(Agg):
    name = "t_test"


class ValueCount(Agg):
    name = "value_count"


# pipeline aggregations
class AvgBucket(Pipeline):
    name = "avg_bucket"


class BucketScript(Pipeline):
    name = "bucket_script"


class BucketSelector(Pipeline):
    name = "bucket_selector"


class CumulativeSum(Pipeline):
    name = "cumulative_sum"


class CumulativeCardinality(Pipeline):
    name = "cumulative_cardinality"


class Derivative(Pipeline):
    name = "derivative"


class ExtendedStatsBucket(Pipeline):
    name = "extended_stats_bucket"


class Inference(Pipeline):
    name = "inference"


class MaxBucket(Pipeline):
    name = "max_bucket"


class MinBucket(Pipeline):
    name = "min_bucket"


class MovingFn(Pipeline):
    name = "moving_fn"


class MovingAvg(Pipeline):
    name = "moving_avg"


class MovingPercentiles(Pipeline):
    name = "moving_percentiles"


class Normalize(Pipeline):
    name = "normalize"


class PercentilesBucket(Pipeline):
    name = "percentiles_bucket"


class SerialDiff(Pipeline):
    name = "serial_diff"


class StatsBucket(Pipeline):
    name = "stats_bucket"


class SumBucket(Pipeline):
    name = "sum_bucket"


class BucketSort(Pipeline):
    name = "bucket_sort"