246 lines
9.4 KiB
Python
246 lines
9.4 KiB
Python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Ops for computing common window functions."""
|
|
|
|
import numpy as np
|
|
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import tensor_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import cond
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import nn_ops
|
|
from tensorflow.python.ops import special_math_ops
|
|
from tensorflow.python.util import dispatch
|
|
from tensorflow.python.util.tf_export import tf_export
|
|
|
|
|
|
def _check_params(window_length, dtype):
|
|
"""Check window_length and dtype params.
|
|
|
|
Args:
|
|
window_length: A scalar value or `Tensor`.
|
|
dtype: The data type to produce. Must be a floating point type.
|
|
|
|
Returns:
|
|
window_length converted to a tensor of type int32.
|
|
|
|
Raises:
|
|
ValueError: If `dtype` is not a floating point type or window_length is not
|
|
a scalar.
|
|
"""
|
|
if not dtype.is_floating:
|
|
raise ValueError('dtype must be a floating point type. Found %s' % dtype)
|
|
window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32)
|
|
window_length.shape.assert_has_rank(0)
|
|
return window_length
|
|
|
|
|
|
@tf_export('signal.kaiser_window')
@dispatch.add_dispatch_support
def kaiser_window(window_length, beta=12., dtype=dtypes.float32, name=None):
  """Generate a [Kaiser window][kaiser].

  The window is `I0(beta * sqrt(1 - (n / halflen)**2)) / I0(beta)` for `n`
  running from `-halflen` to `halflen`, where `halflen` is
  `(window_length - 1) / 2`. It is evaluated with the exponentially scaled
  Bessel function `bessel_i0e` to stay finite for large `beta`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window, see reference below.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type or `window_length` is
      not a scalar.

  [kaiser]:
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.kaiser.html
  """
  with ops.name_scope(name, 'kaiser_window'):
    window_length = _check_params(window_length, dtype)
    window_length_const = tensor_util.constant_value(window_length)
    # Fast path when the length is statically known to be 1.
    if window_length_const == 1:
      return array_ops.ones([1], dtype=dtype)
    # tf.range does not support float16 so we work with float32 initially.
    halflen_float = (
        math_ops.cast(window_length, dtype=dtypes.float32) - 1.0) / 2.0
    # The 0.1 slack keeps the upper endpoint inside the half-open range.
    arg = math_ops.range(-halflen_float, halflen_float + 0.1,
                         dtype=dtypes.float32)
    # Convert everything into given dtype which can be float16.
    arg = math_ops.cast(arg, dtype=dtype)
    beta = math_ops.cast(beta, dtype=dtype)
    one = math_ops.cast(1.0, dtype=dtype)
    halflen_float = math_ops.cast(halflen_float, dtype=dtype)
    # When the length is only known at run time and equals 1, halflen is 0 and
    # a plain division would evaluate 0 / 0. div_no_nan yields 0 instead, so
    # the window becomes exactly [1.0], matching the static fast path above.
    # For every other length the denominator is non-zero and the result is
    # the ordinary quotient.
    ratio = math_ops.div_no_nan(arg, halflen_float)
    # relu clamps tiny negative values caused by rounding before the sqrt.
    num = beta * math_ops.sqrt(nn_ops.relu(one - math_ops.square(ratio)))
    # I0(num) / I0(beta) written in terms of i0e(x) = exp(-|x|) * I0(x).
    window = math_ops.exp(num - beta) * (
        special_math_ops.bessel_i0e(num) / special_math_ops.bessel_i0e(beta))
    return window
|
|
|
|
|
|
@tf_export('signal.kaiser_bessel_derived_window')
@dispatch.add_dispatch_support
def kaiser_bessel_derived_window(window_length, beta=12.,
                                 dtype=dtypes.float32, name=None):
  """Generate a [Kaiser Bessel derived window][kbd].

  The window is assembled from a rising half of `window_length // 2` samples
  followed by its mirror image. The rising half is the square root of the
  normalized running sum of a Kaiser window of length
  `window_length // 2 + 1`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    beta: Beta parameter for Kaiser window.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of type `dtype` with `2 * (window_length // 2)` elements, i.e.
    shape `[window_length]` when `window_length` is even.

  [kbd]:
    https://en.wikipedia.org/wiki/Kaiser_window#Kaiser%E2%80%93Bessel-derived_(KBD)_window
  """
  with ops.name_scope(name, 'kaiser_bessel_derived_window'):
    window_length = _check_params(window_length, dtype)
    half_size = window_length // 2
    base = kaiser_window(half_size + 1, beta, dtype=dtype)
    running_sum = math_ops.cumsum(base)
    # Every partial sum except the last, normalized by the total sum.
    partial_sums = running_sum[:-1]
    total = running_sum[-1]
    rising = math_ops.sqrt(partial_sums / total)
    # The falling half mirrors the rising half.
    return array_ops.concat([rising, rising[::-1]], axis=0)
|
|
|
|
|
|
@tf_export('signal.vorbis_window')
@dispatch.add_dispatch_support
def vorbis_window(window_length, dtype=dtypes.float32, name=None):
  """Generate a [Vorbis power complementary window][vorbis].

  Computes `sin(pi / 2 * sin(pi / N * (n + 0.5))**2)` for
  `n = 0, ..., N - 1`, where `N` is `window_length`.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  [vorbis]:
    https://en.wikipedia.org/wiki/Modified_discrete_cosine_transform#Window_functions
  """
  with ops.name_scope(name, 'vorbis_window'):
    window_length = _check_params(window_length, dtype)
    sample_index = math_ops.cast(math_ops.range(window_length), dtype=dtype)
    length_as_float = math_ops.cast(window_length, dtype=dtype)
    # Phase of the inner sine, sampled at half-integer positions.
    inner_phase = np.pi / length_as_float * (sample_index + 0.5)
    inner_power = math_ops.pow(math_ops.sin(inner_phase), 2.0)
    return math_ops.sin(np.pi / 2.0 * inner_power)
|
|
|
|
|
|
@tf_export('signal.hann_window')
@dispatch.add_dispatch_support
def hann_window(window_length, periodic=True, dtype=dtypes.float32, name=None):
  """Generate a [Hann window][hann].

  This is the raised cosine window with both coefficients equal to 0.5.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name=name,
      default_name='hann_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.5,
      b=0.5)
|
|
|
|
|
|
@tf_export('signal.hamming_window')
@dispatch.add_dispatch_support
def hamming_window(window_length, periodic=True, dtype=dtypes.float32,
                   name=None):
  """Generate a [Hamming][hamming] window.

  This is the raised cosine window with coefficients 0.54 and 0.46.

  Args:
    window_length: A scalar `Tensor` indicating the window length to generate.
    periodic: A bool `Tensor` indicating whether to generate a periodic or
      symmetric window. Periodic windows are typically used for spectral
      analysis while symmetric windows are typically used for digital
      filter design.
    dtype: The data type to produce. Must be a floating point type.
    name: An optional name for the operation.

  Returns:
    A `Tensor` of shape `[window_length]` of type `dtype`.

  Raises:
    ValueError: If `dtype` is not a floating point type.

  [hamming]:
    https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows
  """
  return _raised_cosine_window(
      name=name,
      default_name='hamming_window',
      window_length=window_length,
      periodic=periodic,
      dtype=dtype,
      a=0.54,
      b=0.46)
|
|
|
|
|
|
def _raised_cosine_window(name, default_name, window_length, periodic,
|
|
dtype, a, b):
|
|
"""Helper function for computing a raised cosine window.
|
|
|
|
Args:
|
|
name: Name to use for the scope.
|
|
default_name: Default name to use for the scope.
|
|
window_length: A scalar `Tensor` or integer indicating the window length.
|
|
periodic: A bool `Tensor` indicating whether to generate a periodic or
|
|
symmetric window.
|
|
dtype: A floating point `DType`.
|
|
a: The alpha parameter to the raised cosine window.
|
|
b: The beta parameter to the raised cosine window.
|
|
|
|
Returns:
|
|
A `Tensor` of shape `[window_length]` of type `dtype`.
|
|
|
|
Raises:
|
|
ValueError: If `dtype` is not a floating point type or `window_length` is
|
|
not scalar or `periodic` is not scalar.
|
|
"""
|
|
if not dtype.is_floating:
|
|
raise ValueError('dtype must be a floating point type. Found %s' % dtype)
|
|
|
|
with ops.name_scope(name, default_name, [window_length, periodic]):
|
|
window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32,
|
|
name='window_length')
|
|
window_length.shape.assert_has_rank(0)
|
|
window_length_const = tensor_util.constant_value(window_length)
|
|
if window_length_const == 1:
|
|
return array_ops.ones([1], dtype=dtype)
|
|
periodic = math_ops.cast(
|
|
ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'),
|
|
dtypes.int32)
|
|
periodic.shape.assert_has_rank(0)
|
|
even = 1 - math_ops.mod(window_length, 2)
|
|
|
|
n = math_ops.cast(window_length + periodic * even - 1, dtype=dtype)
|
|
count = math_ops.cast(math_ops.range(window_length), dtype)
|
|
cos_arg = constant_op.constant(2 * np.pi, dtype=dtype) * count / n
|
|
|
|
if window_length_const is not None:
|
|
return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype)
|
|
return cond.cond(
|
|
math_ops.equal(window_length, 1),
|
|
lambda: array_ops.ones([window_length], dtype=dtype),
|
|
lambda: math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype))
|