
Add OpenImageDenoise thirdparty library

JFonS committed 4 years ago (commit ad8abef74c)
100 changed files with 20939 additions and 0 deletions
  1. modules/denoise/SCsub (+118 -0)
  2. modules/denoise/config.py (+15 -0)
  3. modules/denoise/denoise_wrapper.cpp (+69 -0)
  4. modules/denoise/denoise_wrapper.h (+38 -0)
  5. modules/denoise/lightmap_denoiser.cpp (+65 -0)
  6. modules/denoise/lightmap_denoiser.h (+56 -0)
  7. modules/denoise/register_types.cpp (+40 -0)
  8. modules/denoise/register_types.h (+37 -0)
  9. modules/denoise/resource_to_cpp.py (+68 -0)
  10. thirdparty/README.md (+31 -0)
  11. thirdparty/oidn/LICENSE.txt (+202 -0)
  12. thirdparty/oidn/common/barrier.h (+52 -0)
  13. thirdparty/oidn/common/exception.h (+45 -0)
  14. thirdparty/oidn/common/platform.cpp (+114 -0)
  15. thirdparty/oidn/common/platform.h (+131 -0)
  16. thirdparty/oidn/common/ref.h (+163 -0)
  17. thirdparty/oidn/common/tensor.cpp (+83 -0)
  18. thirdparty/oidn/common/tensor.h (+66 -0)
  19. thirdparty/oidn/common/thread.cpp (+297 -0)
  20. thirdparty/oidn/common/thread.h (+202 -0)
  21. thirdparty/oidn/common/timer.h (+49 -0)
  22. thirdparty/oidn/core/api.cpp (+408 -0)
  23. thirdparty/oidn/core/autoencoder.cpp (+535 -0)
  24. thirdparty/oidn/core/autoencoder.h (+120 -0)
  25. thirdparty/oidn/core/buffer.h (+75 -0)
  26. thirdparty/oidn/core/common.h (+136 -0)
  27. thirdparty/oidn/core/device.cpp (+238 -0)
  28. thirdparty/oidn/core/device.h (+102 -0)
  29. thirdparty/oidn/core/filter.cpp (+27 -0)
  30. thirdparty/oidn/core/filter.h (+52 -0)
  31. thirdparty/oidn/core/image.h (+111 -0)
  32. thirdparty/oidn/core/input_reorder.h (+232 -0)
  33. thirdparty/oidn/core/math.h (+78 -0)
  34. thirdparty/oidn/core/network.cpp (+436 -0)
  35. thirdparty/oidn/core/network.h (+112 -0)
  36. thirdparty/oidn/core/node.h (+142 -0)
  37. thirdparty/oidn/core/output_reorder.h (+126 -0)
  38. thirdparty/oidn/core/transfer_function.cpp (+103 -0)
  39. thirdparty/oidn/core/transfer_function.h (+201 -0)
  40. thirdparty/oidn/core/upsample.h (+92 -0)
  41. thirdparty/oidn/core/weights_reorder.h (+99 -0)
  42. thirdparty/oidn/include/OpenImageDenoise/oidn.h (+214 -0)
  43. thirdparty/oidn/include/OpenImageDenoise/oidn.hpp (+468 -0)
  44. thirdparty/oidn/include/OpenImageDenoise/version.h (+23 -0)
  45. thirdparty/oidn/mkl-dnn/LICENSE (+214 -0)
  46. thirdparty/oidn/mkl-dnn/include/mkldnn.h (+1771 -0)
  47. thirdparty/oidn/mkl-dnn/include/mkldnn.hpp (+2615 -0)
  48. thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h (+98 -0)
  49. thirdparty/oidn/mkl-dnn/include/mkldnn_types.h (+1415 -0)
  50. thirdparty/oidn/mkl-dnn/include/mkldnn_version.h (+32 -0)
  51. thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in (+32 -0)
  52. thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp (+104 -0)
  53. thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp (+240 -0)
  54. thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp (+550 -0)
  55. thirdparty/oidn/mkl-dnn/src/common/concat.cpp (+86 -0)
  56. thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp (+211 -0)
  57. thirdparty/oidn/mkl-dnn/src/common/convolution.cpp (+200 -0)
  58. thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp (+56 -0)
  59. thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp (+348 -0)
  60. thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp (+188 -0)
  61. thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp (+293 -0)
  62. thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp (+84 -0)
  63. thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp (+161 -0)
  64. thirdparty/oidn/mkl-dnn/src/common/engine.cpp (+75 -0)
  65. thirdparty/oidn/mkl-dnn/src/common/engine.hpp (+119 -0)
  66. thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp (+106 -0)
  67. thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp (+56 -0)
  68. thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp (+321 -0)
  69. thirdparty/oidn/mkl-dnn/src/common/lrn.cpp (+91 -0)
  70. thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp (+170 -0)
  71. thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp (+280 -0)
  72. thirdparty/oidn/mkl-dnn/src/common/memory.cpp (+238 -0)
  73. thirdparty/oidn/mkl-dnn/src/common/memory.hpp (+63 -0)
  74. thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp (+212 -0)
  75. thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp (+400 -0)
  76. thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp (+295 -0)
  77. thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp (+131 -0)
  78. thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp (+365 -0)
  79. thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp (+115 -0)
  80. thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp (+277 -0)
  81. thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp (+77 -0)
  82. thirdparty/oidn/mkl-dnn/src/common/nstl.hpp (+193 -0)
  83. thirdparty/oidn/mkl-dnn/src/common/pooling.cpp (+114 -0)
  84. thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp (+238 -0)
  85. thirdparty/oidn/mkl-dnn/src/common/primitive.cpp (+103 -0)
  86. thirdparty/oidn/mkl-dnn/src/common/primitive.hpp (+76 -0)
  87. thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp (+290 -0)
  88. thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp (+183 -0)
  89. thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp (+78 -0)
  90. thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp (+174 -0)
  91. thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp (+90 -0)
  92. thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp (+68 -0)
  93. thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp (+89 -0)
  94. thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp (+79 -0)
  95. thirdparty/oidn/mkl-dnn/src/common/query.cpp (+59 -0)
  96. thirdparty/oidn/mkl-dnn/src/common/reorder.cpp (+68 -0)
  97. thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp (+85 -0)
  98. thirdparty/oidn/mkl-dnn/src/common/rnn.cpp (+400 -0)
  99. thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp (+280 -0)
  100. thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp (+112 -0)

+ 118 - 0
modules/denoise/SCsub

@@ -0,0 +1,118 @@
+#!/usr/bin/env python
+
+import resource_to_cpp
+
+Import("env")
+Import("env_modules")
+
+env_oidn = env_modules.Clone()
+
+# Thirdparty source files
+thirdparty_dir = "#thirdparty/oidn/"
+thirdparty_sources = [
+    "core/api.cpp",
+    "core/device.cpp",
+    "core/filter.cpp",
+    "core/network.cpp",
+    "core/autoencoder.cpp",
+    "core/transfer_function.cpp",
+    "weights/rtlightmap_hdr.gen.cpp",
+    "mkl-dnn/src/common/batch_normalization.cpp",
+    "mkl-dnn/src/common/concat.cpp",
+    "mkl-dnn/src/common/convolution.cpp",
+    "mkl-dnn/src/common/convolution_pd.cpp",
+    "mkl-dnn/src/common/deconvolution.cpp",
+    "mkl-dnn/src/common/eltwise.cpp",
+    "mkl-dnn/src/common/engine.cpp",
+    "mkl-dnn/src/common/inner_product.cpp",
+    "mkl-dnn/src/common/inner_product_pd.cpp",
+    "mkl-dnn/src/common/lrn.cpp",
+    "mkl-dnn/src/common/memory.cpp",
+    "mkl-dnn/src/common/memory_desc_wrapper.cpp",
+    "mkl-dnn/src/common/mkldnn_debug.cpp",
+    "mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp",
+    "mkl-dnn/src/common/pooling.cpp",
+    "mkl-dnn/src/common/primitive.cpp",
+    "mkl-dnn/src/common/primitive_attr.cpp",
+    "mkl-dnn/src/common/primitive_desc.cpp",
+    "mkl-dnn/src/common/primitive_exec_types.cpp",
+    "mkl-dnn/src/common/primitive_iterator.cpp",
+    "mkl-dnn/src/common/query.cpp",
+    "mkl-dnn/src/common/reorder.cpp",
+    "mkl-dnn/src/common/rnn.cpp",
+    "mkl-dnn/src/common/scratchpad.cpp",
+    "mkl-dnn/src/common/shuffle.cpp",
+    "mkl-dnn/src/common/softmax.cpp",
+    "mkl-dnn/src/common/stream.cpp",
+    "mkl-dnn/src/common/sum.cpp",
+    "mkl-dnn/src/common/utils.cpp",
+    "mkl-dnn/src/common/verbose.cpp",
+    "mkl-dnn/src/cpu/cpu_barrier.cpp",
+    "mkl-dnn/src/cpu/cpu_concat.cpp",
+    "mkl-dnn/src/cpu/cpu_engine.cpp",
+    "mkl-dnn/src/cpu/cpu_memory.cpp",
+    "mkl-dnn/src/cpu/cpu_reducer.cpp",
+    "mkl-dnn/src/cpu/cpu_reorder.cpp",
+    "mkl-dnn/src/cpu/cpu_sum.cpp",
+    "mkl-dnn/src/cpu/jit_avx2_conv_kernel_f32.cpp",
+    "mkl-dnn/src/cpu/jit_avx2_convolution.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_common_conv_kernel.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_common_conv_winograd_kernel_f32.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_common_convolution.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_common_convolution_winograd.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_2x3.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3.cpp",
+    "mkl-dnn/src/cpu/jit_avx512_core_fp32_wino_conv_4x3_kernel.cpp",
+    "mkl-dnn/src/cpu/jit_sse42_conv_kernel_f32.cpp",
+    "mkl-dnn/src/cpu/jit_sse42_convolution.cpp",
+    "mkl-dnn/src/cpu/jit_transpose_src_utils.cpp",
+    "mkl-dnn/src/cpu/jit_uni_eltwise.cpp",
+    "mkl-dnn/src/cpu/jit_uni_pool_kernel_f32.cpp",
+    "mkl-dnn/src/cpu/jit_uni_pooling.cpp",
+    "mkl-dnn/src/cpu/jit_uni_reorder.cpp",
+    "mkl-dnn/src/cpu/jit_uni_reorder_utils.cpp",
+    "mkl-dnn/src/cpu/jit_utils/jit_utils.cpp",
+    "mkl-dnn/src/cpu/jit_utils/jitprofiling/jitprofiling.c",
+    "common/platform.cpp",
+    "common/thread.cpp",
+    "common/tensor.cpp",
+]
+thirdparty_sources = [thirdparty_dir + file for file in thirdparty_sources]
+
+thirdparty_include_dirs = [
+    "",
+    "include",
+    "mkl-dnn/include",
+    "mkl-dnn/src",
+    "mkl-dnn/src/common",
+    "mkl-dnn/src/cpu/xbyak",
+    "mkl-dnn/src/cpu",
+]
+thirdparty_include_dirs = [thirdparty_dir + file for file in thirdparty_include_dirs]
+
+
+env_oidn.Prepend(CPPPATH=thirdparty_include_dirs)
+env_oidn.Append(
+    CPPDEFINES=[
+        "MKLDNN_THR=MKLDNN_THR_SEQ",
+        "OIDN_STATIC_LIB",
+        "__STDC_CONSTANT_MACROS",
+        "__STDC_LIMIT_MACROS",
+        "DISABLE_VERBOSE",
+        "MKLDNN_ENABLE_CONCURRENT_EXEC",
+        "NDEBUG",
+    ]
+)
+
+env_thirdparty = env_oidn.Clone()
+env_thirdparty.disable_warnings()
+env_thirdparty.add_source_files(env.modules_sources, thirdparty_sources)
+
+weights_in_path = thirdparty_dir + "weights/rtlightmap_hdr.tza"
+weights_out_path = thirdparty_dir + "weights/rtlightmap_hdr.gen.cpp"
+
+env_thirdparty.Depends(weights_out_path, weights_in_path)
+env_thirdparty.CommandNoCache(weights_out_path, weights_in_path, resource_to_cpp.tza_to_cpp)
+
+env_oidn.add_source_files(env.modules_sources, "denoise_wrapper.cpp")
+env_modules.add_source_files(env.modules_sources, ["register_types.cpp", "lightmap_denoiser.cpp"])

+ 15 - 0
modules/denoise/config.py

@@ -0,0 +1,15 @@
+def can_build(env, platform):
+    # Thirdparty dependency OpenImage Denoise includes oneDNN library
+    # which only supports 64-bit architectures.
+    # It's also only relevant for tools build and desktop platforms,
+    # as doing lightmap generation and denoising on Android or HTML5
+    # would be a bit far-fetched.
+    # Note: oneDNN doesn't support ARM64, OIDN needs updating to the latest version
+    supported_platform = platform in ["x11", "osx", "windows", "server"]
+    supported_bits = env["bits"] == "64"
+    supported_arch = env["arch"] != "arm64"
+    return env["tools"] and supported_platform and supported_bits and supported_arch
+
+
+def configure(env):
+    pass

+ 69 - 0
modules/denoise/denoise_wrapper.cpp

@@ -0,0 +1,69 @@
+/*************************************************************************/
+/*  denoise_wrapper.cpp                                                  */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#include "denoise_wrapper.h"
+#include "core/os/copymem.h"
+#include "core/os/memory.h"
+#include "thirdparty/oidn/include/OpenImageDenoise/oidn.h"
+#include <stdio.h>
+
+void *oidn_denoiser_init() {
+	OIDNDeviceImpl *device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
+	oidnCommitDevice(device);
+	return device;
+}
+
+bool oidn_denoise(void *deviceptr, float *p_floats, int p_width, int p_height) {
+	OIDNDeviceImpl *device = (OIDNDeviceImpl *)deviceptr;
+	OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
+	void *input_buffer = memalloc(p_width * p_height * 3 * sizeof(float));
+	copymem(input_buffer, p_floats, p_width * p_height * 3 * sizeof(float));
+	oidnSetSharedFilterImage(filter, "color", input_buffer, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
+	oidnSetSharedFilterImage(filter, "output", (void *)p_floats, OIDN_FORMAT_FLOAT3, p_width, p_height, 0, 0, 0);
+	oidnSetFilter1b(filter, "hdr", true);
+	//oidnSetFilter1f(filter, "hdrScale", 1.0f);
+	//oidnSetFilter1i(filter, "verbose", 4);
+	oidnCommitFilter(filter);
+	oidnExecuteFilter(filter);
+
+	const char *msg;
+	bool success = true;
+	if (oidnGetDeviceError(device, &msg) != OIDN_ERROR_NONE) {
+		printf("LightmapDenoiser: %s\n", msg);
+		success = false;
+	}
+
+	oidnReleaseFilter(filter);
+	return success;
+}
+
+void oidn_denoiser_finish(void *device) {
+	oidnReleaseDevice((OIDNDeviceImpl *)device);
+}

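As a point of reference, here is a minimal sketch (not part of this commit) of how the three wrapper functions above are meant to be driven. The 512x512 size and the placeholder pixel data are invented for illustration; the buffer is interleaved RGB floats that OIDN overwrites in place.

#include "denoise_wrapper.h"

#include <vector>

void example_denoise() {
	const int width = 512, height = 512;
	// Placeholder lightmap: interleaved RGB, 3 floats per pixel.
	std::vector<float> rgb(width * height * 3, 0.5f);

	void *device = oidn_denoiser_init(); // creates and commits an OIDN CPU device
	if (oidn_denoise(device, rgb.data(), width, height)) {
		// rgb now holds the denoised image (the filter writes its output back here).
	}
	oidn_denoiser_finish(device); // releases the device
}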
+ 38 - 0
modules/denoise/denoise_wrapper.h

@@ -0,0 +1,38 @@
+/*************************************************************************/
+/*  denoise_wrapper.h                                                    */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#ifndef DENOISE_WRAPPER_H
+#define DENOISE_WRAPPER_H
+
+void *oidn_denoiser_init();
+bool oidn_denoise(void *device, float *p_floats, int p_width, int p_height);
+void oidn_denoiser_finish(void *device);
+
+#endif // DENOISE_WRAPPER_H

+ 65 - 0
modules/denoise/lightmap_denoiser.cpp

@@ -0,0 +1,65 @@
+/*************************************************************************/
+/*  lightmap_denoiser.cpp                                                */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#include "lightmap_denoiser.h"
+#include "denoise_wrapper.h"
+
+LightmapDenoiser *LightmapDenoiserOIDN::create_oidn_denoiser() {
+	return memnew(LightmapDenoiserOIDN);
+}
+
+void LightmapDenoiserOIDN::make_default_denoiser() {
+	create_function = create_oidn_denoiser;
+}
+
+Ref<Image> LightmapDenoiserOIDN::denoise_image(const Ref<Image> &p_image) {
+	Ref<Image> img = p_image->duplicate();
+
+	img->convert(Image::FORMAT_RGBF);
+
+	PoolByteArray data = img->get_data();
+	{
+		PoolByteArray::Write w = data.write();
+		if (!oidn_denoise(device, (float *)w.ptr(), img->get_width(), img->get_height())) {
+			return p_image;
+		}
+	}
+
+	img->create(img->get_width(), img->get_height(), false, img->get_format(), data);
+	return img;
+}
+
+LightmapDenoiserOIDN::LightmapDenoiserOIDN() {
+	device = oidn_denoiser_init();
+}
+
+LightmapDenoiserOIDN::~LightmapDenoiserOIDN() {
+	oidn_denoiser_finish(device);
+}

+ 56 - 0
modules/denoise/lightmap_denoiser.h

@@ -0,0 +1,56 @@
+/*************************************************************************/
+/*  lightmap_denoiser.h                                                  */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#ifndef LIGHTMAP_DENOISER_H
+#define LIGHTMAP_DENOISER_H
+
+#include "core/class_db.h"
+#include "scene/3d/lightmapper.h"
+
+struct OIDNDeviceImpl;
+
+class LightmapDenoiserOIDN : public LightmapDenoiser {
+	GDCLASS(LightmapDenoiserOIDN, LightmapDenoiser);
+
+protected:
+	void *device = nullptr;
+
+public:
+	static LightmapDenoiser *create_oidn_denoiser();
+
+	Ref<Image> denoise_image(const Ref<Image> &p_image) override;
+
+	static void make_default_denoiser();
+
+	LightmapDenoiserOIDN();
+	~LightmapDenoiserOIDN();
+};
+
+#endif // LIGHTMAP_DENOISER_H

+ 40 - 0
modules/denoise/register_types.cpp

@@ -0,0 +1,40 @@
+/*************************************************************************/
+/*  register_types.cpp                                                   */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#include "register_types.h"
+#include "core/engine.h"
+#include "lightmap_denoiser.h"
+
+void register_denoise_types() {
+	LightmapDenoiserOIDN::make_default_denoiser();
+}
+
+void unregister_denoise_types() {
+}

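register_denoise_types() only installs the factory; nothing in this commit calls the denoiser directly. Below is a rough usage sketch, assuming the LightmapDenoiser base class in scene/3d/lightmapper.h exposes a static create() helper that goes through the create_function pointer set by make_default_denoiser(). That header is not part of this diff, so treat the exact API as an assumption.

#include "core/image.h"
#include "scene/3d/lightmapper.h"

// Hypothetical caller, after register_denoise_types() has run during module init.
Ref<Image> denoise_baked_lightmap(const Ref<Image> &p_lightmap) {
	Ref<LightmapDenoiser> denoiser = LightmapDenoiser::create(); // resolves to LightmapDenoiserOIDN
	if (denoiser.is_valid()) {
		return denoiser->denoise_image(p_lightmap);
	}
	return p_lightmap; // no denoiser available on this platform/build
}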
+ 37 - 0
modules/denoise/register_types.h

@@ -0,0 +1,37 @@
+/*************************************************************************/
+/*  register_types.h                                                     */
+/*************************************************************************/
+/*                       This file is part of:                           */
+/*                           GODOT ENGINE                                */
+/*                      https://godotengine.org                          */
+/*************************************************************************/
+/* Copyright (c) 2007-2020 Juan Linietsky, Ariel Manzur.                 */
+/* Copyright (c) 2014-2020 Godot Engine contributors (cf. AUTHORS.md).   */
+/*                                                                       */
+/* Permission is hereby granted, free of charge, to any person obtaining */
+/* a copy of this software and associated documentation files (the       */
+/* "Software"), to deal in the Software without restriction, including   */
+/* without limitation the rights to use, copy, modify, merge, publish,   */
+/* distribute, sublicense, and/or sell copies of the Software, and to    */
+/* permit persons to whom the Software is furnished to do so, subject to */
+/* the following conditions:                                             */
+/*                                                                       */
+/* The above copyright notice and this permission notice shall be        */
+/* included in all copies or substantial portions of the Software.       */
+/*                                                                       */
+/* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,       */
+/* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF    */
+/* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.*/
+/* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY  */
+/* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,  */
+/* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE     */
+/* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.                */
+/*************************************************************************/
+
+#ifndef DENOISE_REGISTER_TYPES_H
+#define DENOISE_REGISTER_TYPES_H
+
+void register_denoise_types();
+void unregister_denoise_types();
+
+#endif // DENOISE_REGISTER_TYPES_H

+ 68 - 0
modules/denoise/resource_to_cpp.py

@@ -0,0 +1,68 @@
+#!/usr/bin/env python
+
+## ======================================================================== ##
+## Copyright 2009-2019 Intel Corporation                                    ##
+##                                                                          ##
+## Licensed under the Apache License, Version 2.0 (the "License");          ##
+## you may not use this file except in compliance with the License.         ##
+## You may obtain a copy of the License at                                  ##
+##                                                                          ##
+##     http://www.apache.org/licenses/LICENSE-2.0                           ##
+##                                                                          ##
+## Unless required by applicable law or agreed to in writing, software      ##
+## distributed under the License is distributed on an "AS IS" BASIS,        ##
+## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ##
+## See the License for the specific language governing permissions and      ##
+## limitations under the License.                                           ##
+## ======================================================================== ##
+
+import os
+from array import array
+
+# Generates a C++ file from the specified binary resource file
+def generate(in_path, out_path):
+
+    namespace = "oidn::weights"
+    scopes = namespace.split("::")
+
+    file_name = os.path.basename(in_path)
+    var_name = os.path.splitext(file_name)[0]
+
+    with open(in_path, "rb") as in_file, open(out_path, "w") as out_file:
+        # Header
+        out_file.write("// Generated from: %s\n" % file_name)
+        out_file.write("#include <cstddef>\n\n")
+
+        # Open the namespaces
+        for s in scopes:
+            out_file.write("namespace %s {\n" % s)
+        if scopes:
+            out_file.write("\n")
+
+        # Read the file
+        in_data = array("B", in_file.read())
+
+        # Write the size
+        out_file.write("//const size_t %s_size = %d;\n\n" % (var_name, len(in_data)))
+
+        # Write the data
+        out_file.write("unsigned char %s[] = {" % var_name)
+        for i in range(len(in_data)):
+            c = in_data[i]
+            if i > 0:
+                out_file.write(",")
+            if (i + 1) % 20 == 1:
+                out_file.write("\n")
+            out_file.write("%d" % c)
+        out_file.write("\n};\n")
+
+        # Close the namespaces
+        if scopes:
+            out_file.write("\n")
+        for scope in reversed(scopes):
+            out_file.write("} // namespace %s\n" % scope)
+
+
+def tza_to_cpp(target, source, env):
+    for x in zip(source, target):
+        generate(str(x[0]), str(x[1]))

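To make the output format concrete, this is roughly what generate() produces for a hypothetical three-byte input named rtlightmap_hdr.tza (the byte values 7, 13, 42 are invented; the size constant stays commented out, as in the code above):

// Generated from: rtlightmap_hdr.tza
#include <cstddef>

namespace oidn {
namespace weights {

//const size_t rtlightmap_hdr_size = 3;

unsigned char rtlightmap_hdr[] = {
7,13,42
};

} // namespace weights
} // namespace oidn

The SCsub above wires this up via CommandNoCache, so the array is regenerated from weights/rtlightmap_hdr.tza at build time and compiled in as weights/rtlightmap_hdr.gen.cpp.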
+ 31 - 0
thirdparty/README.md

@@ -360,6 +360,37 @@ Files extracted from the upstream source:
 - LICENSE.txt
 
 
+## oidn
+
+- Upstream: https://github.com/OpenImageDenoise/oidn
+- Version: 1.1.0 (c58c5216db05ceef4cde5a096862f2eeffd14c06, 2019)
+- License: Apache 2.0
+
+Files extracted from upstream source:
+
+common/* (except tasking.* and CMakeLists.txt)
+core/*
+include/OpenImageDenoise/* (except version.h.in)
+LICENSE.txt
+mkl-dnn/include/*
+mkl-dnn/src/* (except CMakeLists.txt)
+weights/rtlightmap_hdr.tza
+scripts/resource_to_cpp.py
+
+Modified files:
+Modifications are marked with `// -- GODOT start --` and `// -- GODOT end --`.
+Patch files are provided in `oidn/patches/`.
+
+core/autoencoder.cpp
+core/autoencoder.h
+core/common.h
+core/device.cpp
+core/device.h
+core/transfer_function.cpp
+
+scripts/resource_to_cpp.py (used in modules/denoise/resource_to_cpp.py)
+
+
 ## opus
 
 - Upstream: https://opus-codec.org

+ 202 - 0
thirdparty/oidn/LICENSE.txt

@@ -0,0 +1,202 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "[]"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright [yyyy] [name of copyright owner]
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.

+ 52 - 0
thirdparty/oidn/common/barrier.h

@@ -0,0 +1,52 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+#include <mutex>
+#include <condition_variable>
+
+namespace oidn {
+
+  class Barrier
+  {
+  private:
+    std::mutex m;
+    std::condition_variable cv;
+    volatile int count;
+
+  public:
+    Barrier(int count) : count(count) {}
+
+    void wait()
+    {
+      std::unique_lock<std::mutex> lk(m);
+      count--;
+
+      if (count == 0)
+      {
+        lk.unlock();
+        cv.notify_all();
+      }
+      else
+      {
+        cv.wait(lk, [&]{ return count == 0; });
+      }
+    }
+  };
+
+} // namespace oidn

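A minimal, standalone sketch (not from the commit) of how this one-shot barrier behaves: both workers block in wait() until the count reaches zero, then proceed together.

#include "common/barrier.h" // assumes thirdparty/oidn is on the include path, as set up in the SCsub above

#include <cstdio>
#include <thread>

int main() {
	oidn::Barrier barrier(2);

	auto worker = [&](int id) {
		std::printf("worker %d: before barrier\n", id);
		barrier.wait(); // blocks until both workers have arrived
		std::printf("worker %d: after barrier\n", id);
	};

	std::thread a(worker, 0), b(worker, 1);
	a.join();
	b.join();
	return 0;
}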
+ 45 - 0
thirdparty/oidn/common/exception.h

@@ -0,0 +1,45 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include <exception>
+#include "platform.h"
+
+namespace oidn {
+
+  class Exception : public std::exception
+  {
+  private:
+    Error error;
+    const char* message;
+
+  public:
+    Exception(Error error, const char* message)
+      : error(error), message(message) {}
+
+    Error code() const noexcept
+    {
+      return error;
+    }
+
+    const char* what() const noexcept override
+    {
+      return message;
+    }
+  };
+
+} // namespace oidn

+ 114 - 0
thirdparty/oidn/common/platform.cpp

@@ -0,0 +1,114 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "platform.h"
+
+namespace oidn {
+
+  // ----------------------------------------------------------------------------
+  // Common functions
+  // ----------------------------------------------------------------------------
+
+  void* alignedMalloc(size_t size, size_t alignment)
+  {
+    if (size == 0)
+      return nullptr;
+
+    assert((alignment & (alignment-1)) == 0);
+    void* ptr = _mm_malloc(size, alignment);
+
+    if (ptr == nullptr)
+      throw std::bad_alloc();
+
+    return ptr;
+  }
+
+  void alignedFree(void* ptr)
+  {
+    if (ptr)
+      _mm_free(ptr);
+  }
+
+  // ----------------------------------------------------------------------------
+  // System information
+  // ----------------------------------------------------------------------------
+
+  std::string getPlatformName()
+  {
+    std::string name;
+
+  #if defined(__linux__)
+    name = "Linux";
+  #elif defined(__FreeBSD__)
+    name = "FreeBSD";
+  #elif defined(__CYGWIN__)
+    name = "Cygwin";
+  #elif defined(_WIN32)
+    name = "Windows";
+  #elif defined(__APPLE__)
+    name = "macOS";
+  #elif defined(__unix__)
+    name = "Unix";
+  #else
+    return "Unknown";
+  #endif
+
+  #if defined(__x86_64__) || defined(_M_X64) || defined(__ia64__) || defined(__aarch64__)
+    name += " (64-bit)";
+  #else
+    name += " (32-bit)";
+  #endif
+
+    return name;
+  }
+
+  std::string getCompilerName()
+  {
+  #if defined(__INTEL_COMPILER)
+    int mayor = __INTEL_COMPILER / 100 % 100;
+    int minor = __INTEL_COMPILER % 100;
+    std::string version = "Intel Compiler ";
+    version += toString(mayor);
+    version += "." + toString(minor);
+  #if defined(__INTEL_COMPILER_UPDATE)
+    version += "." + toString(__INTEL_COMPILER_UPDATE);
+  #endif
+    return version;
+  #elif defined(__clang__)
+    return "Clang " __clang_version__;
+  #elif defined(__GNUC__)
+    return "GCC " __VERSION__;
+  #elif defined(_MSC_VER)
+    std::string version = toString(_MSC_FULL_VER);
+    version.insert(4, ".");
+    version.insert(9, ".");
+    version.insert(2, ".");
+    return "Visual C++ Compiler " + version;
+  #else
+    return "Unknown";
+  #endif
+  }
+
+  std::string getBuildName()
+  {
+  #if defined(NDEBUG)
+    return "Release";
+  #else
+    return "Debug";
+  #endif
+  }
+
+} // namespace oidn

+ 131 - 0
thirdparty/oidn/common/platform.h

@@ -0,0 +1,131 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#if defined(_WIN32)
+  #define WIN32_LEAN_AND_MEAN
+  #define NOMINMAX
+  #include <windows.h>
+#elif defined(__APPLE__)
+  #include <sys/sysctl.h>
+#endif
+
+#include <xmmintrin.h>
+#include <cstdint>
+#include <climits>
+#include <limits>
+#include <atomic>
+#include <algorithm>
+#include <memory>
+#include <cmath>
+#include <string>
+#include <sstream>
+#include <iostream>
+#include <cassert>
+#include "include/OpenImageDenoise/oidn.hpp"
+
+namespace oidn {
+
+  // ----------------------------------------------------------------------------
+  // Macros
+  // ----------------------------------------------------------------------------
+
+  #if defined(_WIN32)
+    // Windows
+    #if !defined(__noinline)
+      #define __noinline     __declspec(noinline)
+    #endif
+  #else
+    // Unix
+    #if !defined(__forceinline)
+      #define __forceinline  inline __attribute__((always_inline))
+    #endif
+    #if !defined(__noinline)
+      #define __noinline     __attribute__((noinline))
+    #endif
+  #endif
+
+  #ifndef UNUSED
+    #define UNUSED(x) ((void)x)
+  #endif
+  #ifndef MAYBE_UNUSED
+    #define MAYBE_UNUSED(x) UNUSED(x)
+  #endif
+
+  // ----------------------------------------------------------------------------
+  // Error handling and debugging
+  // ----------------------------------------------------------------------------
+
+  struct Verbose
+  {
+    int verbose;
+
+    Verbose(int v = 0) : verbose(v) {}
+    __forceinline bool isVerbose(int v = 1) const { return v <= verbose; }
+  };
+
+  #define OIDN_WARNING(message) { if (isVerbose()) std::cerr << "Warning: " << message << std::endl; }
+  #define OIDN_FATAL(message) throw std::runtime_error(message);
+
+  // ----------------------------------------------------------------------------
+  // Common functions
+  // ----------------------------------------------------------------------------
+
+  using std::min;
+  using std::max;
+
+  template<typename T>
+  __forceinline T clamp(const T& value, const T& minValue, const T& maxValue)
+  {
+    return min(max(value, minValue), maxValue);
+  }
+
+  void* alignedMalloc(size_t size, size_t alignment);
+  void alignedFree(void* ptr);
+
+  template<typename T>
+  inline std::string toString(const T& a)
+  {
+    std::stringstream sm;
+    sm << a;
+    return sm.str();
+  }
+
+#if defined(__APPLE__)
+  template<typename T>
+  bool getSysctl(const char* name, T& value)
+  {
+    int64_t result = 0;
+    size_t size = sizeof(result);
+
+    if (sysctlbyname(name, &result, &size, nullptr, 0) != 0)
+      return false;
+
+    value = T(result);
+    return true;
+  }
+#endif
+
+  // ----------------------------------------------------------------------------
+  // System information
+  // ----------------------------------------------------------------------------
+
+  std::string getPlatformName();
+  std::string getCompilerName();
+  std::string getBuildName();
+
+} // namespace oidn

+ 163 - 0
thirdparty/oidn/common/ref.h

@@ -0,0 +1,163 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+
+namespace oidn {
+
+  class RefCount
+  {
+  private:
+    std::atomic<size_t> count;
+
+  public:
+    __forceinline RefCount(int count = 0) noexcept : count(count) {}
+
+    __forceinline size_t incRef() noexcept
+    {
+      return count.fetch_add(1) + 1;
+    }
+
+    __forceinline size_t decRef()
+    {
+      const size_t newCount = decRefKeep();
+      if (newCount == 0)
+        destroy();
+      return newCount;
+    }
+
+    __forceinline size_t decRefKeep() noexcept
+    {
+      return count.fetch_add(-1) - 1;
+    }
+
+    __forceinline void destroy()
+    {
+      delete this;
+    }
+
+  protected:
+    // Disable copying
+    RefCount(const RefCount&) = delete;
+    RefCount& operator =(const RefCount&) = delete;
+
+    virtual ~RefCount() noexcept = default;
+  };
+
+  template<typename T>
+  class Ref
+  {
+  private:
+    T* ptr;
+
+  public:
+    __forceinline Ref() noexcept : ptr(nullptr) {}
+    __forceinline Ref(std::nullptr_t) noexcept : ptr(nullptr) {}
+    __forceinline Ref(const Ref& other) noexcept : ptr(other.ptr) { if (ptr) ptr->incRef(); }
+    __forceinline Ref(Ref&& other) noexcept : ptr(other.ptr) { other.ptr = nullptr; }
+    __forceinline Ref(T* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
+
+    template<typename Y>
+    __forceinline Ref(const Ref<Y>& other) noexcept : ptr(other.get()) { if (ptr) ptr->incRef(); }
+
+    template<typename Y>
+    __forceinline explicit Ref(Y* ptr) noexcept : ptr(ptr) { if (ptr) ptr->incRef(); }
+
+    __forceinline ~Ref() { if (ptr) ptr->decRef(); }
+
+    __forceinline Ref& operator =(const Ref& other)
+    {
+      if (other.ptr)
+        other.ptr->incRef();
+      if (ptr)
+        ptr->decRef();
+      ptr = other.ptr;
+      return *this;
+    }
+
+    __forceinline Ref& operator =(Ref&& other)
+    {
+      if (ptr)
+        ptr->decRef();
+      ptr = other.ptr;
+      other.ptr = nullptr;
+      return *this;
+    }
+
+    __forceinline Ref& operator =(T* other)
+    {
+      if (other)
+        other->incRef();
+      if (ptr)
+        ptr->decRef();
+      ptr = other;
+      return *this;
+    }
+
+    __forceinline Ref& operator =(std::nullptr_t)
+    {
+      if (ptr)
+        ptr->decRef();
+      ptr = nullptr;
+      return *this;
+    }
+
+    __forceinline operator bool() const noexcept { return ptr != nullptr; }
+
+    __forceinline T& operator  *() const noexcept { return *ptr; }
+    __forceinline T* operator ->() const noexcept { return  ptr; }
+
+    __forceinline T* get() const noexcept { return ptr; }
+
+    __forceinline T* detach() noexcept
+    {
+      T* res = ptr;
+      ptr = nullptr;
+      return res;
+    }
+  };
+
+  template<typename T> __forceinline bool operator < (const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr   <  b.ptr;   }
+
+  template<typename T> __forceinline bool operator ==(const Ref<T>& a, std::nullptr_t)  noexcept { return a.ptr   == nullptr; }
+  template<typename T> __forceinline bool operator ==(std::nullptr_t,  const Ref<T>& b) noexcept { return nullptr == b.ptr;   }
+  template<typename T> __forceinline bool operator ==(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr   == b.ptr;   }
+
+  template<typename T> __forceinline bool operator !=(const Ref<T>& a, std::nullptr_t)  noexcept { return a.ptr   != nullptr; }
+  template<typename T> __forceinline bool operator !=(std::nullptr_t,  const Ref<T>& b) noexcept { return nullptr != b.ptr;   }
+  template<typename T> __forceinline bool operator !=(const Ref<T>& a, const Ref<T>& b) noexcept { return a.ptr   != b.ptr;   }
+
+  template<typename T, typename... Args>
+  __forceinline Ref<T> makeRef(Args&&... args)
+  {
+    return Ref<T>(new T(std::forward<Args>(args)...));
+  }
+
+  template<typename T, typename Y>
+  __forceinline Ref<Y> staticRefCast(const Ref<T>& a)
+  {
+    return Ref<Y>(static_cast<Y*>(a.get()));
+  }
+
+  template<typename T, typename Y>
+  __forceinline Ref<Y> dynamicRefCast(const Ref<T>& a)
+  {
+    return Ref<Y>(dynamic_cast<Y*>(a.get()));
+  }
+
+} // namespace oidn

+ 83 - 0
thirdparty/oidn/common/tensor.cpp

@@ -0,0 +1,83 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "exception.h"
+#include "tensor.h"
+
+namespace oidn {
+
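+  // Layout of a tensor archive as parsed below (all fields are read directly
+  // in host byte order):
+  //   uint16  magic (0x41D7)
+  //   uint8   major version, uint8 minor version
+  //   int32   number of tensors
+  //   per tensor:
+  //     uint8   name length, followed by the name characters
+  //     uint8   ndims
+  //     int32   dims[ndims]
+  //     char    format[ndims]
+  //     char    data type ('f' = float32)
+  //     float   data[product of dims]
+  // The returned Tensor::data pointers reference the input buffer directly;
+  // no tensor data is copied.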
+  std::map<std::string, Tensor> parseTensors(void* buffer)
+  {
+    char* input = (char*)buffer;
+
+    // Parse the magic value
+    const int magic = *(unsigned short*)input;
+    if (magic != 0x41D7)
+      throw Exception(Error::InvalidOperation, "invalid tensor archive");
+    input += sizeof(unsigned short);
+
+    // Parse the version
+    const int majorVersion = *(unsigned char*)input++;
+    const int minorVersion = *(unsigned char*)input++;
+    UNUSED(minorVersion);
+    if (majorVersion > 1)
+      throw Exception(Error::InvalidOperation, "unsupported tensor archive version");
+
+    // Parse the number of tensors
+    const int numTensors = *(int*)input;
+    input += sizeof(int);
+
+    // Parse the tensors
+    std::map<std::string, Tensor> tensorMap;
+    for (int i = 0; i < numTensors; ++i)
+    {
+      Tensor tensor;
+
+      // Parse the name
+      const int nameLen = *(unsigned char*)input++;
+      std::string name(input, nameLen);
+      input += nameLen;
+
+      // Parse the number of dimensions
+      const int ndims = *(unsigned char*)input++;
+
+      // Parse the shape of the tensor
+      tensor.dims.resize(ndims);
+      for (int i = 0; i < ndims; ++i)
+        tensor.dims[i] = ((int*)input)[i];
+      input += ndims * sizeof(int);
+
+      // Parse the format of the tensor
+      tensor.format = std::string(input, input + ndims);
+      input += ndims;
+
+      // Parse the data type of the tensor
+      const char type = *(unsigned char*)input++;
+      if (type != 'f') // only float32 is supported
+        throw Exception(Error::InvalidOperation, "unsupported tensor data type");
+
+      // Skip the data
+      tensor.data = (float*)input;
+      input += tensor.size() * sizeof(float);
+
+      // Add the tensor to the map
+      tensorMap.emplace(name, std::move(tensor));
+    }
+
+    return tensorMap;
+  }
+
+} // namespace oidn

+ 66 - 0
thirdparty/oidn/common/tensor.h

@@ -0,0 +1,66 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+#include <vector>
+#include <map>
+
+namespace oidn {
+
+  template<typename T>
+  using shared_vector = std::shared_ptr<std::vector<T>>;
+
+  // Generic tensor
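+  // The data pointer either references externally owned memory (e.g. a weights
+  // blob parsed with parseTensors below) or the optional reference-counted
+  // buffer allocated by the sized constructor.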
+  struct Tensor
+  {
+    float* data;
+    std::vector<int64_t> dims;
+    std::string format;
+    shared_vector<char> buffer; // optional, only for reference counting
+
+    __forceinline Tensor() : data(nullptr) {}
+
+    __forceinline Tensor(const std::vector<int64_t>& dims, const std::string& format)
+      : dims(dims),
+        format(format)
+    {
+      buffer = std::make_shared<std::vector<char>>(size() * sizeof(float));
+      data = (float*)buffer->data();
+    }
+
+    __forceinline operator bool() const { return data != nullptr; }
+
+    __forceinline int ndims() const { return (int)dims.size(); }
+
+    // Returns the number of values
+    __forceinline size_t size() const
+    {
+      size_t size = 1;
+      for (int i = 0; i < ndims(); ++i)
+        size *= dims[i];
+      return size;
+    }
+
+    __forceinline float& operator [](size_t i) { return data[i]; }
+    __forceinline const float& operator [](size_t i) const { return data[i]; }
+  };
+
+  // Parses tensors from a buffer
+  std::map<std::string, Tensor> parseTensors(void* buffer);
+
+} // namespace oidn

+ 297 - 0
thirdparty/oidn/common/thread.cpp

@@ -0,0 +1,297 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#if defined(_MSC_VER)
+  #pragma warning (disable : 4146) // unary minus operator applied to unsigned type, result still unsigned
+#endif
+
+#if defined(__APPLE__)
+  #include <mach/thread_act.h>
+  #include <mach/mach_init.h>
+#endif
+
+#include "thread.h"
+#include <fstream>
+
+namespace oidn {
+
+#if defined(_WIN32)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - Windows
+  // --------------------------------------------------------------------------
+
+  ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
+    : Verbose(verbose)
+  {
+    HMODULE hLib = GetModuleHandle(TEXT("kernel32"));
+    pGetLogicalProcessorInformationEx = (GetLogicalProcessorInformationExFunc)GetProcAddress(hLib, "GetLogicalProcessorInformationEx");
+    pSetThreadGroupAffinity = (SetThreadGroupAffinityFunc)GetProcAddress(hLib, "SetThreadGroupAffinity");
+
+    if (pGetLogicalProcessorInformationEx && pSetThreadGroupAffinity)
+    {
+      // Get logical processor information
+      PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX buffer = nullptr;
+      DWORD bufferSize = 0;
+
+      // First call the function with an empty buffer to get the required buffer size
+      BOOL result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
+      if (result || GetLastError() != ERROR_INSUFFICIENT_BUFFER)
+      {
+        OIDN_WARNING("GetLogicalProcessorInformationEx failed");
+        return;
+      }
+
+      // Allocate the buffer
+      buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)malloc(bufferSize);
+      if (!buffer)
+      {
+        OIDN_WARNING("SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX allocation failed");
+        return;
+      }
+
+      // Call the function again, now with the properly sized buffer
+      result = pGetLogicalProcessorInformationEx(RelationProcessorCore, buffer, &bufferSize);
+      if (!result)
+      {
+        OIDN_WARNING("GetLogicalProcessorInformationEx failed");
+        free(buffer);
+        return;
+      }
+
+      // Iterate over the logical processor information structures
+      // There should be one structure for each physical core
+      char* ptr = (char*)buffer;
+      while (ptr < (char*)buffer + bufferSize)
+      {
+        PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX item = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX)ptr;
+        if (item->Relationship == RelationProcessorCore && item->Processor.GroupCount > 0)
+        {
+          // Iterate over the groups
+          int numThreads = 0;
+          for (int group = 0; (group < item->Processor.GroupCount) && (numThreads < numThreadsPerCore); ++group)
+          {
+            GROUP_AFFINITY coreAffinity = item->Processor.GroupMask[group];
+            while ((coreAffinity.Mask != 0) && (numThreads < numThreadsPerCore))
+            {
+              // Extract the next set bit/thread from the mask
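+              // (Mask & -Mask isolates the lowest set bit; the unary minus on the
+              // unsigned mask is intentional, hence the C4146 warning disabled above)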
+              GROUP_AFFINITY threadAffinity = coreAffinity;
+              threadAffinity.Mask = threadAffinity.Mask & -threadAffinity.Mask;
+
+              // Push the affinity for this thread
+              affinities.push_back(threadAffinity);
+              oldAffinities.push_back(threadAffinity);
+              numThreads++;
+
+              // Remove this bit/thread from the mask
+              coreAffinity.Mask ^= threadAffinity.Mask;
+            }
+          }
+        }
+
+        // Next structure
+        ptr += item->Size;
+      }
+
+      // Free the buffer
+      free(buffer);
+    }
+  }
+
+  void ThreadAffinity::set(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    // Save the current affinity and set the new one
+    const HANDLE thread = GetCurrentThread();
+    if (!pSetThreadGroupAffinity(thread, &affinities[threadIndex], &oldAffinities[threadIndex]))
+      OIDN_WARNING("SetThreadGroupAffinity failed");
+  }
+
+  void ThreadAffinity::restore(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    // Restore the original affinity
+    const HANDLE thread = GetCurrentThread();
+    if (!pSetThreadGroupAffinity(thread, &oldAffinities[threadIndex], nullptr))
+      OIDN_WARNING("SetThreadGroupAffinity failed");
+  }
+
+#elif defined(__linux__)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - Linux
+  // --------------------------------------------------------------------------
+
+  ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
+    : Verbose(verbose)
+  {
+    std::vector<int> threadIds;
+
+    // Parse the thread/CPU topology
+    for (int cpuId = 0; ; cpuId++)
+    {
+      std::fstream fs;
+      std::string cpu = std::string("/sys/devices/system/cpu/cpu") + std::to_string(cpuId) + std::string("/topology/thread_siblings_list");
+      fs.open(cpu.c_str(), std::fstream::in);
+      if (fs.fail()) break;
+
+      int i;
+      int j = 0;
+      while ((j < numThreadsPerCore) && (fs >> i))
+      {
+        if (std::none_of(threadIds.begin(), threadIds.end(), [&](int id) { return id == i; }))
+          threadIds.push_back(i);
+
+        if (fs.peek() == ',')
+          fs.ignore();
+        j++;
+      }
+
+      fs.close();
+    }
+
+  #if 0
+    for (size_t i = 0; i < threadIds.size(); ++i)
+      std::cout << "thread " << i << " -> " << threadIds[i] << std::endl;
+  #endif
+
+    // Create the affinity structures
+    affinities.resize(threadIds.size());
+    oldAffinities.resize(threadIds.size());
+
+    for (size_t i = 0; i < threadIds.size(); ++i)
+    {
+      cpu_set_t affinity;
+      CPU_ZERO(&affinity);
+      CPU_SET(threadIds[i], &affinity);
+
+      affinities[i] = affinity;
+      oldAffinities[i] = affinity;
+    }
+  }
+
+  void ThreadAffinity::set(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    const pthread_t thread = pthread_self();
+
+    // Save the current affinity
+    if (pthread_getaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
+    {
+      OIDN_WARNING("pthread_getaffinity_np failed");
+      oldAffinities[threadIndex] = affinities[threadIndex];
+      return;
+    }
+
+    // Set the new affinity
+    if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &affinities[threadIndex]) != 0)
+      OIDN_WARNING("pthread_setaffinity_np failed");
+  }
+
+  void ThreadAffinity::restore(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    const pthread_t thread = pthread_self();
+
+    // Restore the original affinity
+    if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &oldAffinities[threadIndex]) != 0)
+      OIDN_WARNING("pthread_setaffinity_np failed");
+  }
+
+#elif defined(__APPLE__)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - macOS
+  // --------------------------------------------------------------------------
+
+  ThreadAffinity::ThreadAffinity(int numThreadsPerCore, int verbose)
+    : Verbose(verbose)
+  {
+    // Query the thread/CPU topology
+    int numPhysicalCpus;
+    int numLogicalCpus;
+
+    if (!getSysctl("hw.physicalcpu", numPhysicalCpus) || !getSysctl("hw.logicalcpu", numLogicalCpus))
+    {
+      OIDN_WARNING("sysctlbyname failed");
+      return;
+    }
+
+    if ((numLogicalCpus % numPhysicalCpus != 0) && (numThreadsPerCore > 1))
+      return; // this shouldn't happen
+    const int maxThreadsPerCore = numLogicalCpus / numPhysicalCpus;
+
+    // Create the affinity structures
+    // macOS doesn't support binding a thread to a specific core, but we can at least group threads which
+    // should be on the same core together
+    for (int core = 1; core <= numPhysicalCpus; ++core) // tags start from 1!
+    {
+      thread_affinity_policy affinity;
+      affinity.affinity_tag = core;
+
+      for (int thread = 0; thread < min(numThreadsPerCore, maxThreadsPerCore); ++thread)
+      {
+        affinities.push_back(affinity);
+        oldAffinities.push_back(affinity);
+      }
+    }
+  }
+
+  void ThreadAffinity::set(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    const auto thread = mach_thread_self();
+
+    // Save the current affinity
+    mach_msg_type_number_t policyCount = THREAD_AFFINITY_POLICY_COUNT;
+    boolean_t getDefault = FALSE;
+    if (thread_policy_get(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], &policyCount, &getDefault) != KERN_SUCCESS)
+    {
+      OIDN_WARNING("thread_policy_get failed");
+      oldAffinities[threadIndex] = affinities[threadIndex];
+      return;
+    }
+
+    // Set the new affinity
+    if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&affinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
+      OIDN_WARNING("thread_policy_set failed");
+  }
+
+  void ThreadAffinity::restore(int threadIndex)
+  {
+    if (threadIndex >= (int)affinities.size())
+      return;
+
+    const auto thread = mach_thread_self();
+
+    // Restore the original affinity
+    if (thread_policy_set(thread, THREAD_AFFINITY_POLICY, (thread_policy_t)&oldAffinities[threadIndex], THREAD_AFFINITY_POLICY_COUNT) != KERN_SUCCESS)
+      OIDN_WARNING("thread_policy_set failed");
+  }
+
+#endif
+
+} // namespace oidn

+ 202 - 0
thirdparty/oidn/common/thread.h

@@ -0,0 +1,202 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+
+#if !defined(_WIN32)
+  #include <pthread.h>
+  #include <sched.h>
+  #if defined(__APPLE__)
+    #include <mach/thread_policy.h>
+  #endif
+#endif
+
+#include <vector>
+#include <mutex>
+
+namespace oidn {
+
+  // --------------------------------------------------------------------------
+  // ThreadLocal
+  // --------------------------------------------------------------------------
+
+  // Wrapper which makes any variable thread-local
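+  // A separate instance of T is created lazily on the first get() from each
+  // thread; all instances are destroyed together with the ThreadLocal object.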
+  template<typename T>
+  class ThreadLocal : public Verbose
+  {
+  private:
+  #if defined(_WIN32)
+    DWORD key;
+  #else
+    pthread_key_t key;
+  #endif
+
+    std::vector<T*> instances;
+    std::mutex mutex;
+
+  public:
+    ThreadLocal(int verbose = 0)
+      : Verbose(verbose)
+    {
+    #if defined(_WIN32)
+      key = TlsAlloc();
+      if (key == TLS_OUT_OF_INDEXES)
+        OIDN_FATAL("TlsAlloc failed");
+    #else
+      if (pthread_key_create(&key, nullptr) != 0)
+        OIDN_FATAL("pthread_key_create failed");
+    #endif
+    }
+
+    ~ThreadLocal()
+    {
+      std::lock_guard<std::mutex> lock(mutex);
+      for (T* ptr : instances)
+        delete ptr;
+
+    #if defined(_WIN32)
+      if (!TlsFree(key))
+        OIDN_WARNING("TlsFree failed");
+    #else
+      if (pthread_key_delete(key) != 0)
+        OIDN_WARNING("pthread_key_delete failed");
+    #endif
+    }
+
+    T& get()
+    {
+    #if defined(_WIN32)
+      T* ptr = (T*)TlsGetValue(key);
+    #else
+      T* ptr = (T*)pthread_getspecific(key);
+    #endif
+
+      if (ptr)
+        return *ptr;
+
+      ptr = new T;
+      std::lock_guard<std::mutex> lock(mutex);
+      instances.push_back(ptr);
+
+    #if defined(_WIN32)
+      if (!TlsSetValue(key, ptr))
+        OIDN_FATAL("TlsSetValue failed");
+    #else
+      if (pthread_setspecific(key, ptr) != 0)
+        OIDN_FATAL("pthread_setspecific failed");
+    #endif
+
+      return *ptr;
+    }
+  };
+
+#if defined(_WIN32)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - Windows
+  // --------------------------------------------------------------------------
+
+  class ThreadAffinity : public Verbose
+  {
+  private:
+    typedef BOOL (WINAPI *GetLogicalProcessorInformationExFunc)(LOGICAL_PROCESSOR_RELATIONSHIP,
+                                                                PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX,
+                                                                PDWORD);
+
+    typedef BOOL (WINAPI *SetThreadGroupAffinityFunc)(HANDLE,
+                                                      CONST GROUP_AFFINITY*,
+                                                      PGROUP_AFFINITY);
+
+    GetLogicalProcessorInformationExFunc pGetLogicalProcessorInformationEx = nullptr;
+    SetThreadGroupAffinityFunc pSetThreadGroupAffinity = nullptr;
+
+    std::vector<GROUP_AFFINITY> affinities;    // thread affinities
+    std::vector<GROUP_AFFINITY> oldAffinities; // original thread affinities
+
+  public:
+    ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
+
+    int getNumThreads() const
+    {
+      return (int)affinities.size();
+    }
+
+    // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
+    void set(int threadIndex);
+
+    // Restores the affinity of the thread
+    void restore(int threadIndex);
+  };
+
+#elif defined(__linux__)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - Linux
+  // --------------------------------------------------------------------------
+
+  class ThreadAffinity : public Verbose
+  {
+  private:
+    std::vector<cpu_set_t> affinities;    // thread affinities
+    std::vector<cpu_set_t> oldAffinities; // original thread affinities
+
+  public:
+    ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
+
+    int getNumThreads() const
+    {
+      return (int)affinities.size();
+    }
+
+    // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
+    void set(int threadIndex);
+
+    // Restores the affinity of the thread
+    void restore(int threadIndex);
+  };
+
+#elif defined(__APPLE__)
+
+  // --------------------------------------------------------------------------
+  // ThreadAffinity - macOS
+  // --------------------------------------------------------------------------
+
+  class ThreadAffinity : public Verbose
+  {
+  private:
+    std::vector<thread_affinity_policy> affinities;    // thread affinities
+    std::vector<thread_affinity_policy> oldAffinities; // original thread affinities
+
+  public:
+    ThreadAffinity(int numThreadsPerCore = INT_MAX, int verbose = 0);
+
+    int getNumThreads() const
+    {
+      return (int)affinities.size();
+    }
+
+    // Sets the affinity (0..numThreads-1) of the thread after saving the current affinity
+    void set(int threadIndex);
+
+    // Restores the affinity of the thread
+    void restore(int threadIndex);
+  };
+
+#endif
+
+} // namespace oidn

+ 49 - 0
thirdparty/oidn/common/timer.h

@@ -0,0 +1,49 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "platform.h"
+#include <chrono>
+
+namespace oidn {
+
+  class Timer
+  {
+  private:
+    using clock = std::chrono::high_resolution_clock;
+
+    std::chrono::time_point<clock> start;
+
+  public:
+    Timer()
+    {
+      reset();
+    }
+
+    void reset()
+    {
+      start = clock::now();
+    }
+
+    double query() const
+    {
+      auto end = clock::now();
+      return std::chrono::duration_cast<std::chrono::duration<double>>(end - start).count();
+    }
+  };
+
+} // namespace oidn

+ 408 - 0
thirdparty/oidn/core/api.cpp

@@ -0,0 +1,408 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#ifdef _WIN32
+#  define OIDN_API extern "C" __declspec(dllexport)
+#else
+#  define OIDN_API extern "C" __attribute__ ((visibility ("default")))
+#endif
+
+// Locks the device that owns the specified object
+// Use *only* inside OIDN_TRY/CATCH!
+#define OIDN_LOCK(obj) \
+  std::lock_guard<std::mutex> lock(obj->getDevice()->getMutex());
+
+// Try/catch for converting exceptions to errors
+#define OIDN_TRY \
+  try {
+
+#define OIDN_CATCH(obj) \
+  } catch (Exception& e) {                                                                          \
+    Device::setError(obj ? obj->getDevice() : nullptr, e.code(), e.what());                         \
+  } catch (std::bad_alloc&) {                                                                       \
+    Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory");        \
+  } catch (mkldnn::error& e) {                                                                      \
+    if (e.status == mkldnn_out_of_memory)                                                           \
+      Device::setError(obj ? obj->getDevice() : nullptr, Error::OutOfMemory, "out of memory");      \
+    else                                                                                            \
+      Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.message);                \
+  } catch (std::exception& e) {                                                                     \
+    Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, e.what());                   \
+  } catch (...) {                                                                                   \
+    Device::setError(obj ? obj->getDevice() : nullptr, Error::Unknown, "unknown exception caught"); \
+  }
+
+#include "device.h"
+#include "filter.h"
+#include <mutex>
+
+namespace oidn {
+
+  namespace
+  {
+    __forceinline void checkHandle(void* handle)
+    {
+      if (handle == nullptr)
+        throw Exception(Error::InvalidArgument, "invalid handle");
+    }
+
+    template<typename T>
+    __forceinline void retainObject(T* obj)
+    {
+      if (obj)
+      {
+        obj->incRef();
+      }
+      else
+      {
+        OIDN_TRY
+          checkHandle(obj);
+        OIDN_CATCH(obj)
+      }
+    }
+
+    template<typename T>
+    __forceinline void releaseObject(T* obj)
+    {
+      if (obj == nullptr || obj->decRefKeep() == 0)
+      {
+        OIDN_TRY
+          checkHandle(obj);
+          OIDN_LOCK(obj);
+          obj->destroy();
+        OIDN_CATCH(obj)
+      }
+    }
+
+    template<>
+    __forceinline void releaseObject(Device* obj)
+    {
+      if (obj == nullptr || obj->decRefKeep() == 0)
+      {
+        OIDN_TRY
+          checkHandle(obj);
+          // Do NOT lock the device because it owns the mutex
+          obj->destroy();
+        OIDN_CATCH(obj)
+      }
+    }
+  }
+
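+  // Illustrative sketch of typical C API usage (not part of this file). The
+  // filter type string depends on which filters are compiled in; in this Godot
+  // build the generic "RT" filter is disabled, so the lightmap denoiser wrapper
+  // would use e.g. "RTLightmap". Zero byte strides mean tightly packed data.
+  //
+  //   OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_DEFAULT);
+  //   oidnCommitDevice(device);
+  //   OIDNFilter filter = oidnNewFilter(device, "RTLightmap");
+  //   oidnSetSharedFilterImage(filter, "color",  inPtr,  OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
+  //   oidnSetSharedFilterImage(filter, "output", outPtr, OIDN_FORMAT_FLOAT3, width, height, 0, 0, 0);
+  //   oidnCommitFilter(filter);
+  //   oidnExecuteFilter(filter);
+  //   const char* errorMessage;
+  //   if (oidnGetDeviceError(device, &errorMessage) != OIDN_ERROR_NONE)
+  //     { /* handle errorMessage */ }
+  //   oidnReleaseFilter(filter);
+  //   oidnReleaseDevice(device);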
+  OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type)
+  {
+    Ref<Device> device = nullptr;
+    OIDN_TRY
+      if (type == OIDN_DEVICE_TYPE_CPU || type == OIDN_DEVICE_TYPE_DEFAULT)
+        device = makeRef<Device>();
+      else
+        throw Exception(Error::InvalidArgument, "invalid device type");
+    OIDN_CATCH(device)
+    return (OIDNDevice)device.detach();
+  }
+
+  OIDN_API void oidnRetainDevice(OIDNDevice hDevice)
+  {
+    Device* device = (Device*)hDevice;
+    retainObject(device);
+  }
+
+  OIDN_API void oidnReleaseDevice(OIDNDevice hDevice)
+  {
+    Device* device = (Device*)hDevice;
+    releaseObject(device);
+  }
+
+  OIDN_API void oidnSetDevice1b(OIDNDevice hDevice, const char* name, bool value)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      device->set1i(name, value);
+    OIDN_CATCH(device)
+  }
+
+  OIDN_API void oidnSetDevice1i(OIDNDevice hDevice, const char* name, int value)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      device->set1i(name, value);
+    OIDN_CATCH(device)
+  }
+
+  OIDN_API bool oidnGetDevice1b(OIDNDevice hDevice, const char* name)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      return device->get1i(name);
+    OIDN_CATCH(device)
+    return false;
+  }
+
+  OIDN_API int oidnGetDevice1i(OIDNDevice hDevice, const char* name)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      return device->get1i(name);
+    OIDN_CATCH(device)
+    return 0;
+  }
+
+  OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice hDevice, OIDNErrorFunction func, void* userPtr)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      device->setErrorFunction((ErrorFunction)func, userPtr);
+    OIDN_CATCH(device)
+  }
+
+  OIDN_API OIDNError oidnGetDeviceError(OIDNDevice hDevice, const char** outMessage)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      return (OIDNError)Device::getError(device, outMessage);
+    OIDN_CATCH(device)
+    if (outMessage) *outMessage = "";
+    return OIDN_ERROR_UNKNOWN;
+  }
+
+  OIDN_API void oidnCommitDevice(OIDNDevice hDevice)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      device->commit();
+    OIDN_CATCH(device)
+  }
+
+  OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice hDevice, size_t byteSize)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      Ref<Buffer> buffer = device->newBuffer(byteSize);
+      return (OIDNBuffer)buffer.detach();
+    OIDN_CATCH(device)
+    return nullptr;
+  }
+
+  OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice hDevice, void* ptr, size_t byteSize)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      Ref<Buffer> buffer = device->newBuffer(ptr, byteSize);
+      return (OIDNBuffer)buffer.detach();
+    OIDN_CATCH(device)
+    return nullptr;
+  }
+
+  OIDN_API void oidnRetainBuffer(OIDNBuffer hBuffer)
+  {
+    Buffer* buffer = (Buffer*)hBuffer;
+    retainObject(buffer);
+  }
+
+  OIDN_API void oidnReleaseBuffer(OIDNBuffer hBuffer)
+  {
+    Buffer* buffer = (Buffer*)hBuffer;
+    releaseObject(buffer);
+  }
+
+  OIDN_API void* oidnMapBuffer(OIDNBuffer hBuffer, OIDNAccess access, size_t byteOffset, size_t byteSize)
+  {
+    Buffer* buffer = (Buffer*)hBuffer;
+    OIDN_TRY
+      checkHandle(hBuffer);
+      OIDN_LOCK(buffer);
+      return buffer->map(byteOffset, byteSize);
+    OIDN_CATCH(buffer)
+    return nullptr;
+  }
+
+  OIDN_API void oidnUnmapBuffer(OIDNBuffer hBuffer, void* mappedPtr)
+  {
+    Buffer* buffer = (Buffer*)hBuffer;
+    OIDN_TRY
+      checkHandle(hBuffer);
+      OIDN_LOCK(buffer);
+      return buffer->unmap(mappedPtr);
+    OIDN_CATCH(buffer)
+  }
+
+  OIDN_API OIDNFilter oidnNewFilter(OIDNDevice hDevice, const char* type)
+  {
+    Device* device = (Device*)hDevice;
+    OIDN_TRY
+      checkHandle(hDevice);
+      OIDN_LOCK(device);
+      Ref<Filter> filter = device->newFilter(type);
+      return (OIDNFilter)filter.detach();
+    OIDN_CATCH(device)
+    return nullptr;
+  }
+
+  OIDN_API void oidnRetainFilter(OIDNFilter hFilter)
+  {
+    Filter* filter = (Filter*)hFilter;
+    retainObject(filter);
+  }
+
+  OIDN_API void oidnReleaseFilter(OIDNFilter hFilter)
+  {
+    Filter* filter = (Filter*)hFilter;
+    releaseObject(filter);
+  }
+
+  OIDN_API void oidnSetFilterImage(OIDNFilter hFilter, const char* name,
+                                   OIDNBuffer hBuffer, OIDNFormat format,
+                                   size_t width, size_t height,
+                                   size_t byteOffset,
+                                   size_t bytePixelStride, size_t byteRowStride)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      checkHandle(hBuffer);
+      OIDN_LOCK(filter);
+      Ref<Buffer> buffer = (Buffer*)hBuffer;
+      if (buffer->getDevice() != filter->getDevice())
+        throw Exception(Error::InvalidArgument, "the specified objects are bound to different devices");
+      Image data(buffer, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
+      filter->setImage(name, data);
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API void oidnSetSharedFilterImage(OIDNFilter hFilter, const char* name,
+                                         void* ptr, OIDNFormat format,
+                                         size_t width, size_t height,
+                                         size_t byteOffset,
+                                         size_t bytePixelStride, size_t byteRowStride)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      Image data(ptr, (Format)format, (int)width, (int)height, byteOffset, bytePixelStride, byteRowStride);
+      filter->setImage(name, data);
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API void oidnSetFilter1b(OIDNFilter hFilter, const char* name, bool value)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->set1i(name, int(value));
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API bool oidnGetFilter1b(OIDNFilter hFilter, const char* name)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      return filter->get1i(name);
+    OIDN_CATCH(filter)
+    return false;
+  }
+
+  OIDN_API void oidnSetFilter1i(OIDNFilter hFilter, const char* name, int value)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->set1i(name, value);
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API int oidnGetFilter1i(OIDNFilter hFilter, const char* name)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      return filter->get1i(name);
+    OIDN_CATCH(filter)
+    return 0;
+  }
+
+  OIDN_API void oidnSetFilter1f(OIDNFilter hFilter, const char* name, float value)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->set1f(name, value);
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API float oidnGetFilter1f(OIDNFilter hFilter, const char* name)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      return filter->get1f(name);
+    OIDN_CATCH(filter)
+    return 0;
+  }
+
+  OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter hFilter, OIDNProgressMonitorFunction func, void* userPtr)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->setProgressMonitorFunction(func, userPtr);
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API void oidnCommitFilter(OIDNFilter hFilter)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->commit();
+    OIDN_CATCH(filter)
+  }
+
+  OIDN_API void oidnExecuteFilter(OIDNFilter hFilter)
+  {
+    Filter* filter = (Filter*)hFilter;
+    OIDN_TRY
+      checkHandle(hFilter);
+      OIDN_LOCK(filter);
+      filter->execute();
+    OIDN_CATCH(filter)
+  }
+
+} // namespace oidn

+ 535 - 0
thirdparty/oidn/core/autoencoder.cpp

@@ -0,0 +1,535 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "autoencoder.h"
+
+namespace oidn {
+
+  // --------------------------------------------------------------------------
+  // AutoencoderFilter
+  // --------------------------------------------------------------------------
+
+  AutoencoderFilter::AutoencoderFilter(const Ref<Device>& device)
+    : Filter(device)
+  {
+  }
+
+  void AutoencoderFilter::setImage(const std::string& name, const Image& data)
+  {
+    if (name == "color")
+      color = data;
+    else if (name == "albedo")
+      albedo = data;
+    else if (name == "normal")
+      normal = data;
+    else if (name == "output")
+      output = data;
+
+    dirty = true;
+  }
+
+  void AutoencoderFilter::set1i(const std::string& name, int value)
+  {
+    if (name == "hdr")
+      hdr = value;
+    else if (name == "srgb")
+      srgb = value;
+    else if (name == "maxMemoryMB")
+      maxMemoryMB = value;
+
+    dirty = true;
+  }
+
+  int AutoencoderFilter::get1i(const std::string& name)
+  {
+    if (name == "hdr")
+      return hdr;
+    else if (name == "srgb")
+      return srgb;
+    else if (name == "maxMemoryMB")
+      return maxMemoryMB;
+    else if (name == "alignment")
+      return alignment;
+    else if (name == "overlap")
+      return overlap;
+    else
+      throw Exception(Error::InvalidArgument, "invalid parameter");
+  }
+
+  void AutoencoderFilter::set1f(const std::string& name, float value)
+  {
+    if (name == "hdrScale")
+      hdrScale = value;
+
+    dirty = true;
+  }
+
+  float AutoencoderFilter::get1f(const std::string& name)
+  {
+    if (name == "hdrScale")
+      return hdrScale;
+    else
+      throw Exception(Error::InvalidArgument, "invalid parameter");
+  }
+
+  void AutoencoderFilter::commit()
+  {
+    if (!dirty)
+      return;
+
+    // -- GODOT start --
+    //device->executeTask([&]()
+    //{
+    // -- GODOT end --
+
+      if (mayiuse(avx512_common))
+        net = buildNet<16>();
+      else
+        net = buildNet<8>();
+
+    // -- GODOT start --
+    //});
+    // -- GODOT end --
+
+    dirty = false;
+  }
+
+  void AutoencoderFilter::execute()
+  {
+    if (dirty)
+      throw Exception(Error::InvalidOperation, "changes to the filter are not committed");
+
+    if (!net)
+      return;
+    // -- GODOT start --
+    //device->executeTask([&]()
+    //{
+    // -- GODOT end --
+      Progress progress;
+      progress.func = progressFunc;
+      progress.userPtr = progressUserPtr;
+      progress.taskCount = tileCountH * tileCountW;
+
+      // Iterate over the tiles
+      int tileIndex = 0;
+
+      for (int i = 0; i < tileCountH; ++i)
+      {
+        const int h = i * (tileH - 2*overlap); // input tile position (including overlap)
+        const int overlapBeginH = i > 0            ? overlap : 0; // overlap on the top
+        const int overlapEndH   = i < tileCountH-1 ? overlap : 0; // overlap on the bottom
+        const int tileH1 = min(H - h, tileH); // input tile size (including overlap)
+        const int tileH2 = tileH1 - overlapBeginH - overlapEndH; // output tile size
+        const int alignOffsetH = tileH - roundUp(tileH1, alignment); // align to the bottom in the tile buffer
+
+        for (int j = 0; j < tileCountW; ++j)
+        {
+          const int w = j * (tileW - 2*overlap); // input tile position (including overlap)
+          const int overlapBeginW = j > 0            ? overlap : 0; // overlap on the left
+          const int overlapEndW   = j < tileCountW-1 ? overlap : 0; // overlap on the right
+          const int tileW1 = min(W - w, tileW); // input tile size (including overlap)
+          const int tileW2 = tileW1 - overlapBeginW - overlapEndW; // output tile size
+          const int alignOffsetW = tileW - roundUp(tileW1, alignment); // align to the right in the tile buffer
+
+          // Set the input tile
+          inputReorder->setTile(h, w,
+                                alignOffsetH, alignOffsetW,
+                                tileH1, tileW1);
+
+          // Set the output tile
+          outputReorder->setTile(alignOffsetH + overlapBeginH, alignOffsetW + overlapBeginW,
+                                 h + overlapBeginH, w + overlapBeginW,
+                                 tileH2, tileW2);
+
+          //printf("Tile: %d %d -> %d %d\n", w+overlapBeginW, h+overlapBeginH, w+overlapBeginW+tileW2, h+overlapBeginH+tileH2);
+
+          // Denoise the tile
+          net->execute(progress, tileIndex);
+
+          // Next tile
+          tileIndex++;
+        }
+      }
+    // -- GODOT start --
+    //});
+    // -- GODOT end --
+  }
+
+  void AutoencoderFilter::computeTileSize()
+  {
+    const int minTileSize = 3*overlap;
+    const int estimatedBytesPerPixel = mayiuse(avx512_common) ? estimatedBytesPerPixel16 : estimatedBytesPerPixel8;
+    const int64_t maxTilePixels = (int64_t(maxMemoryMB)*1024*1024 - estimatedBytesBase) / estimatedBytesPerPixel;
+
+    tileCountH = 1;
+    tileCountW = 1;
+    tileH = roundUp(H, alignment);
+    tileW = roundUp(W, alignment);
+
+    // Divide the image into tiles until the tile size gets below the threshold
+    while (int64_t(tileH) * tileW > maxTilePixels)
+    {
+      if (tileH > minTileSize && tileH > tileW)
+      {
+        tileCountH++;
+        tileH = max(roundUp(ceilDiv(H - 2*overlap, tileCountH), alignment) + 2*overlap, minTileSize);
+      }
+      else if (tileW > minTileSize)
+      {
+        tileCountW++;
+        tileW = max(roundUp(ceilDiv(W - 2*overlap, tileCountW), alignment) + 2*overlap, minTileSize);
+      }
+      else
+        break;
+    }
+
+    // Compute the final number of tiles
+    tileCountH = (H > tileH) ? ceilDiv(H - 2*overlap, tileH - 2*overlap) : 1;
+    tileCountW = (W > tileW) ? ceilDiv(W - 2*overlap, tileW - 2*overlap) : 1;
+
+    if (device->isVerbose(2))
+    {
+      std::cout << "Tile size : " << tileW << "x" << tileH << std::endl;
+      std::cout << "Tile count: " << tileCountW << "x" << tileCountH << std::endl;
+    }
+  }
+
+  template<int K>
+  std::shared_ptr<Executable> AutoencoderFilter::buildNet()
+  {
+    H = color.height;
+    W = color.width;
+
+    // Configure the network
+    int inputC;
+    void* weightPtr;
+
+    if (srgb && hdr)
+      throw Exception(Error::InvalidOperation, "srgb and hdr modes cannot be enabled at the same time");
+
+    if (color && !albedo && !normal && weightData.hdr)
+    {
+      inputC = 3;
+      weightPtr = hdr ? weightData.hdr : weightData.ldr;
+    }
+    else if (color && albedo && !normal && weightData.hdr_alb)
+    {
+      inputC = 6;
+      weightPtr = hdr ? weightData.hdr_alb : weightData.ldr_alb;
+    }
+    else if (color && albedo && normal && weightData.hdr_alb_nrm)
+    {
+      inputC = 9;
+      weightPtr = hdr ? weightData.hdr_alb_nrm : weightData.ldr_alb_nrm;
+    }
+    else
+    {
+      throw Exception(Error::InvalidOperation, "unsupported combination of input features");
+    }
+
+    if (!output)
+      throw Exception(Error::InvalidOperation, "output image not specified");
+
+    if ((color.format != Format::Float3)
+        || (albedo && albedo.format != Format::Float3)
+        || (normal && normal.format != Format::Float3)
+        || (output.format != Format::Float3))
+      throw Exception(Error::InvalidOperation, "unsupported image format");
+
+    if ((albedo && (albedo.width != W || albedo.height != H))
+        || (normal && (normal.width != W || normal.height != H))
+        || (output.width != W || output.height != H))
+      throw Exception(Error::InvalidOperation, "image size mismatch");
+
+    // Compute the tile size
+    computeTileSize();
+
+    // If the image size is zero, there is nothing else to do
+    if (H <= 0 || W <= 0)
+      return nullptr;
+
+    // Parse the weights
+    const auto weightMap = parseTensors(weightPtr);
+
+    // Create the network
+    std::shared_ptr<Network<K>> net = std::make_shared<Network<K>>(device, weightMap);
+
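+    // The layers below form a U-Net-style encoder/decoder: conv+pool stages
+    // downsample the input, and each upsample stage is concatenated with the
+    // matching pool output (skip connection) before further convolutions.
+    // The "-> tempN/concatN" annotations note which buffer holds each result.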
+    // Compute the tensor sizes
+    const auto inputDims        = memory::dims({1, inputC, tileH, tileW});
+    const auto inputReorderDims = net->getInputReorderDims(inputDims, alignment);   //-> concat0
+
+    const auto conv1Dims     = net->getConvDims("conv1", inputReorderDims);         //-> temp0
+    const auto conv1bDims    = net->getConvDims("conv1b", conv1Dims);               //-> temp1
+    const auto pool1Dims     = net->getPoolDims(conv1bDims);                        //-> concat1
+    const auto conv2Dims     = net->getConvDims("conv2", pool1Dims);                //-> temp0
+    const auto pool2Dims     = net->getPoolDims(conv2Dims);                         //-> concat2
+    const auto conv3Dims     = net->getConvDims("conv3", pool2Dims);                //-> temp0
+    const auto pool3Dims     = net->getPoolDims(conv3Dims);                         //-> concat3
+    const auto conv4Dims     = net->getConvDims("conv4", pool3Dims);                //-> temp0
+    const auto pool4Dims     = net->getPoolDims(conv4Dims);                         //-> concat4
+    const auto conv5Dims     = net->getConvDims("conv5", pool4Dims);                //-> temp0
+    const auto pool5Dims     = net->getPoolDims(conv5Dims);                         //-> temp1
+    const auto upsample4Dims = net->getUpsampleDims(pool5Dims);                     //-> concat4
+    const auto concat4Dims   = net->getConcatDims(upsample4Dims, pool4Dims);
+    const auto conv6Dims     = net->getConvDims("conv6", concat4Dims);              //-> temp0
+    const auto conv6bDims    = net->getConvDims("conv6b", conv6Dims);               //-> temp1
+    const auto upsample3Dims = net->getUpsampleDims(conv6bDims);                    //-> concat3
+    const auto concat3Dims   = net->getConcatDims(upsample3Dims, pool3Dims);
+    const auto conv7Dims     = net->getConvDims("conv7", concat3Dims);              //-> temp0
+    const auto conv7bDims    = net->getConvDims("conv7b", conv7Dims);               //-> temp1
+    const auto upsample2Dims = net->getUpsampleDims(conv7bDims);                    //-> concat2
+    const auto concat2Dims   = net->getConcatDims(upsample2Dims, pool2Dims);
+    const auto conv8Dims     = net->getConvDims("conv8", concat2Dims);              //-> temp0
+    const auto conv8bDims    = net->getConvDims("conv8b", conv8Dims);               //-> temp1
+    const auto upsample1Dims = net->getUpsampleDims(conv8bDims);                    //-> concat1
+    const auto concat1Dims   = net->getConcatDims(upsample1Dims, pool1Dims);
+    const auto conv9Dims     = net->getConvDims("conv9", concat1Dims);              //-> temp0
+    const auto conv9bDims    = net->getConvDims("conv9b", conv9Dims);               //-> temp1
+    const auto upsample0Dims = net->getUpsampleDims(conv9bDims);                    //-> concat0
+    const auto concat0Dims   = net->getConcatDims(upsample0Dims, inputReorderDims);
+    const auto conv10Dims    = net->getConvDims("conv10", concat0Dims);             //-> temp0
+    const auto conv10bDims   = net->getConvDims("conv10b", conv10Dims);             //-> temp1
+    const auto conv11Dims    = net->getConvDims("conv11", conv10bDims);             //-> temp0
+
+    const auto outputDims = memory::dims({1, 3, tileH, tileW});
+
+    // Allocate two temporary ping-pong buffers to decrease memory usage
+    const auto temp0Dims = getMaxTensorDims({
+      conv1Dims,
+      conv2Dims,
+      conv3Dims,
+      conv4Dims,
+      conv5Dims,
+      conv6Dims,
+      conv7Dims,
+      conv8Dims,
+      conv9Dims,
+      conv10Dims,
+      conv11Dims
+    });
+
+    const auto temp1Dims = getMaxTensorDims({
+      conv1bDims,
+      pool5Dims,
+      conv6bDims,
+      conv7bDims,
+      conv8bDims,
+      conv9bDims,
+      conv10bDims,
+    });
+
+    auto temp0 = net->allocTensor(temp0Dims);
+    auto temp1 = net->allocTensor(temp1Dims);
+
+    // Allocate enough memory to hold the concat outputs. Then use the first
+    // half to hold the previous conv output and the second half to hold the
+    // pool/orig image output. This works because everything is C dimension
+    // outermost, padded to K floats, and all the concats are on the C dimension.
+    auto concat0Dst = net->allocTensor(concat0Dims);
+    auto concat1Dst = net->allocTensor(concat1Dims);
+    auto concat2Dst = net->allocTensor(concat2Dims);
+    auto concat3Dst = net->allocTensor(concat3Dims);
+    auto concat4Dst = net->allocTensor(concat4Dims);
+
+    // Transfer function
+    std::shared_ptr<TransferFunction> transferFunc = makeTransferFunc();
+
+    // Autoexposure
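+    // If hdrScale is NaN (its "not set" sentinel), an autoexposure node estimates
+    // the exposure from the color image; otherwise the given scale is applied.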
+    if (auto tf = std::dynamic_pointer_cast<HDRTransferFunction>(transferFunc))
+    {
+      if (isnan(hdrScale))
+        net->addAutoexposure(color, tf);
+      else
+        tf->setExposure(hdrScale);
+    }
+
+    // Input reorder
+    auto inputReorderDst = net->castTensor(inputReorderDims, concat0Dst, upsample0Dims);
+    inputReorder = net->addInputReorder(color, albedo, normal,
+                                        transferFunc,
+                                        alignment, inputReorderDst);
+
+    // conv1
+    auto conv1 = net->addConv("conv1", inputReorder->getDst(), temp0);
+
+    // conv1b
+    auto conv1b = net->addConv("conv1b", conv1->getDst(), temp1);
+
+    // pool1
+    // Adjust pointer for pool1 to eliminate concat1
+    auto pool1Dst = net->castTensor(pool1Dims, concat1Dst, upsample1Dims);
+    auto pool1 = net->addPool(conv1b->getDst(), pool1Dst);
+
+    // conv2
+    auto conv2 = net->addConv("conv2", pool1->getDst(), temp0);
+
+    // pool2
+    // Adjust pointer for pool2 to eliminate concat2
+    auto pool2Dst = net->castTensor(pool2Dims, concat2Dst, upsample2Dims);
+    auto pool2 = net->addPool(conv2->getDst(), pool2Dst);
+
+    // conv3
+    auto conv3 = net->addConv("conv3", pool2->getDst(), temp0);
+
+    // pool3
+    // Adjust pointer for pool3 to eliminate concat3
+    auto pool3Dst = net->castTensor(pool3Dims, concat3Dst, upsample3Dims);
+    auto pool3 = net->addPool(conv3->getDst(), pool3Dst);
+
+    // conv4
+    auto conv4 = net->addConv("conv4", pool3->getDst(), temp0);
+
+    // pool4
+    // Adjust pointer for pool4 to eliminate concat4
+    auto pool4Dst = net->castTensor(pool4Dims, concat4Dst, upsample4Dims);
+    auto pool4 = net->addPool(conv4->getDst(), pool4Dst);
+
+    // conv5
+    auto conv5 = net->addConv("conv5", pool4->getDst(), temp0);
+
+    // pool5
+    auto pool5 = net->addPool(conv5->getDst(), temp1);
+
+    // upsample4
+    auto upsample4Dst = net->castTensor(upsample4Dims, concat4Dst);
+    auto upsample4 = net->addUpsample(pool5->getDst(), upsample4Dst);
+
+    // conv6
+    auto conv6 = net->addConv("conv6", concat4Dst, temp0);
+
+    // conv6b
+    auto conv6b = net->addConv("conv6b", conv6->getDst(), temp1);
+
+    // upsample3
+    auto upsample3Dst = net->castTensor(upsample3Dims, concat3Dst);
+    auto upsample3 = net->addUpsample(conv6b->getDst(), upsample3Dst);
+
+    // conv7
+    auto conv7 = net->addConv("conv7", concat3Dst, temp0);
+
+    // conv7b
+    auto conv7b = net->addConv("conv7b", conv7->getDst(), temp1);
+
+    // upsample2
+    auto upsample2Dst = net->castTensor(upsample2Dims, concat2Dst);
+    auto upsample2 = net->addUpsample(conv7b->getDst(), upsample2Dst);
+
+    // conv8
+    auto conv8 = net->addConv("conv8", concat2Dst, temp0);
+
+    // conv8b
+    auto conv8b = net->addConv("conv8b", conv8->getDst(), temp1);
+
+    // upsample1
+    auto upsample1Dst = net->castTensor(upsample1Dims, concat1Dst);
+    auto upsample1 = net->addUpsample(conv8b->getDst(), upsample1Dst);
+
+    // conv9
+    auto conv9 = net->addConv("conv9", concat1Dst, temp0);
+
+    // conv9b
+    auto conv9b = net->addConv("conv9b", conv9->getDst(), temp1);
+
+    // upsample0
+    auto upsample0Dst = net->castTensor(upsample0Dims, concat0Dst);
+    auto upsample0 = net->addUpsample(conv9b->getDst(), upsample0Dst);
+
+    // conv10
+    auto conv10 = net->addConv("conv10", concat0Dst, temp0);
+
+    // conv10b
+    auto conv10b = net->addConv("conv10b", conv10->getDst(), temp1);
+
+    // conv11
+    auto conv11 = net->addConv("conv11", conv10b->getDst(), temp0, false /* no relu */);
+
+    // Output reorder
+    outputReorder = net->addOutputReorder(conv11->getDst(), transferFunc, output);
+
+    net->finalize();
+    return net;
+  }
+
+  std::shared_ptr<TransferFunction> AutoencoderFilter::makeTransferFunc()
+  {
+    if (hdr)
+      return std::make_shared<PQXTransferFunction>();
+    else if (srgb)
+      return std::make_shared<LinearTransferFunction>();
+    else
+      return std::make_shared<GammaTransferFunction>();
+  }
+
+// -- GODOT start --
+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
+#if 0
+// -- GODOT end --
+
+  // --------------------------------------------------------------------------
+  // RTFilter
+  // --------------------------------------------------------------------------
+
+  namespace weights
+  {
+    // LDR
+    extern unsigned char rt_ldr[];         // color
+    extern unsigned char rt_ldr_alb[];     // color, albedo
+    extern unsigned char rt_ldr_alb_nrm[]; // color, albedo, normal
+
+    // HDR
+    extern unsigned char rt_hdr[];         // color
+    extern unsigned char rt_hdr_alb[];     // color, albedo
+    extern unsigned char rt_hdr_alb_nrm[]; // color, albedo, normal
+  }
+
+  RTFilter::RTFilter(const Ref<Device>& device)
+    : AutoencoderFilter(device)
+  {
+    weightData.ldr         = weights::rt_ldr;
+    weightData.ldr_alb     = weights::rt_ldr_alb;
+    weightData.ldr_alb_nrm = weights::rt_ldr_alb_nrm;
+    weightData.hdr         = weights::rt_hdr;
+    weightData.hdr_alb     = weights::rt_hdr_alb;
+    weightData.hdr_alb_nrm = weights::rt_hdr_alb_nrm;
+  }
+// -- GODOT start --
+#endif
+// -- GODOT end --
+
+  // --------------------------------------------------------------------------
+  // RTLightmapFilter
+  // --------------------------------------------------------------------------
+
+  namespace weights
+  {
+    // HDR
+    extern unsigned char rtlightmap_hdr[]; // color
+  }
+
+  RTLightmapFilter::RTLightmapFilter(const Ref<Device>& device)
+    : AutoencoderFilter(device)
+  {
+    weightData.hdr = weights::rtlightmap_hdr;
+
+    hdr = true;
+  }
+
+  std::shared_ptr<TransferFunction> RTLightmapFilter::makeTransferFunc()
+  {
+    return std::make_shared<LogTransferFunction>();
+  }
+
+} // namespace oidn

+ 120 - 0
thirdparty/oidn/core/autoencoder.h

@@ -0,0 +1,120 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "filter.h"
+#include "network.h"
+#include "transfer_function.h"
+
+namespace oidn {
+
+  // --------------------------------------------------------------------------
+  // AutoencoderFilter - Direct-predicting autoencoder
+  // --------------------------------------------------------------------------
+
+  class AutoencoderFilter : public Filter
+  {
+  protected:
+    static constexpr int alignment       = 32;  // required spatial alignment in pixels (padding may be necessary)
+    static constexpr int receptiveField  = 222; // receptive field in pixels
+    static constexpr int overlap         = roundUp(receptiveField / 2, alignment); // required spatial overlap between tiles in pixels
+
+    static constexpr int estimatedBytesBase       = 16*1024*1024; // estimated base memory usage
+    static constexpr int estimatedBytesPerPixel8  = 889;          // estimated memory usage per pixel for K=8
+    static constexpr int estimatedBytesPerPixel16 = 2185;         // estimated memory usage per pixel for K=16
+
+    Image color;
+    Image albedo;
+    Image normal;
+    Image output;
+    bool hdr = false;
+    float hdrScale = std::numeric_limits<float>::quiet_NaN();
+    bool srgb = false;
+    int maxMemoryMB = 6000; // approximate maximum memory usage in MBs
+
+    int H = 0;          // image height
+    int W = 0;          // image width
+    int tileH = 0;      // tile height
+    int tileW = 0;      // tile width
+    int tileCountH = 1; // number of tiles in H dimension
+    int tileCountW = 1; // number of tiles in W dimension
+
+    std::shared_ptr<Executable> net;
+    std::shared_ptr<Node> inputReorder;
+    std::shared_ptr<Node> outputReorder;
+
+    struct
+    {
+      void* ldr         = nullptr;
+      void* ldr_alb     = nullptr;
+      void* ldr_alb_nrm = nullptr;
+      void* hdr         = nullptr;
+      void* hdr_alb     = nullptr;
+      void* hdr_alb_nrm = nullptr;
+    } weightData;
+
+    explicit AutoencoderFilter(const Ref<Device>& device);
+    virtual std::shared_ptr<TransferFunction> makeTransferFunc();
+
+  public:
+    void setImage(const std::string& name, const Image& data) override;
+    void set1i(const std::string& name, int value) override;
+    int get1i(const std::string& name) override;
+    void set1f(const std::string& name, float value) override;
+    float get1f(const std::string& name) override;
+
+    void commit() override;
+    void execute() override;
+
+  private:
+    void computeTileSize();
+
+    template<int K>
+    std::shared_ptr<Executable> buildNet();
+
+    bool isCommitted() const { return bool(net); }
+  };
+
+  // --------------------------------------------------------------------------
+  // RTFilter - Generic ray tracing denoiser
+  // --------------------------------------------------------------------------
+
+// -- GODOT start --
+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
+#if 0
+// -- GODOT end --
+  class RTFilter : public AutoencoderFilter
+  {
+  public:
+    explicit RTFilter(const Ref<Device>& device);
+  };
+// -- GODOT start --
+#endif
+// -- GODOT end --
+
+  // --------------------------------------------------------------------------
+  // RTLightmapFilter - Ray traced lightmap denoiser
+  // --------------------------------------------------------------------------
+
+  class RTLightmapFilter : public AutoencoderFilter
+  {
+  public:
+    explicit RTLightmapFilter(const Ref<Device>& device);
+    std::shared_ptr<TransferFunction> makeTransferFunc() override;
+  };
+
+} // namespace oidn

+ 75 - 0
thirdparty/oidn/core/buffer.h

@@ -0,0 +1,75 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+#include "device.h"
+
+namespace oidn {
+
+  class Device;
+
+  // Buffer which may or may not own its data
+  class Buffer : public RefCount
+  {
+  private:
+    char* ptr;
+    size_t byteSize;
+    bool shared;
+    Ref<Device> device;
+
+  public:
+    __forceinline Buffer(const Ref<Device>& device, size_t size)
+      : ptr((char*)alignedMalloc(size, 64)),
+        byteSize(size),
+        shared(false),
+        device(device) {}
+
+    __forceinline Buffer(const Ref<Device>& device, void* data, size_t size)
+      : ptr((char*)data),
+        byteSize(size),
+        shared(true),
+        device(device)
+    {
+      if (data == nullptr)
+        throw Exception(Error::InvalidArgument, "buffer pointer null");
+    }
+
+    __forceinline ~Buffer()
+    {
+      if (!shared)
+        alignedFree(ptr);
+    }
+
+    __forceinline char* data() { return ptr; }
+    __forceinline const char* data() const { return ptr; }
+    __forceinline size_t size() const { return byteSize; }
+
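+    // The buffer lives in plain host memory, so mapping is just pointer arithmetic
+    // and unmapping is a no-op.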
+    void* map(size_t offset, size_t size)
+    {
+      if (offset + size > byteSize)
+        throw Exception(Error::InvalidArgument, "buffer region out of range");
+
+      return ptr + offset;
+    }
+
+    void unmap(void* mappedPtr) {}
+
+    Device* getDevice() { return device.get(); }
+  };
+
+} // namespace oidn

+ 136 - 0
thirdparty/oidn/core/common.h

@@ -0,0 +1,136 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common/platform.h"
+
+#include "mkl-dnn/include/mkldnn.hpp"
+#include "mkl-dnn/include/mkldnn_debug.h"
+#include "mkl-dnn/src/common/mkldnn_thread.hpp"
+#include "mkl-dnn/src/common/type_helpers.hpp"
+#include "mkl-dnn/src/cpu/jit_generator.hpp"
+
+#include "common/ref.h"
+#include "common/exception.h"
+#include "common/thread.h"
+// -- GODOT start --
+//#include "common/tasking.h"
+// -- GODOT end --
+#include "math.h"
+
+namespace oidn {
+
+  using namespace mkldnn;
+  using namespace mkldnn::impl::cpu;
+  using mkldnn::impl::parallel_nd;
+  using mkldnn::impl::memory_desc_matches_tag;
+
+
+  inline size_t getFormatBytes(Format format)
+  {
+    switch (format)
+    {
+    case Format::Undefined: return 1;
+    case Format::Float:     return sizeof(float);
+    case Format::Float2:    return sizeof(float)*2;
+    case Format::Float3:    return sizeof(float)*3;
+    case Format::Float4:    return sizeof(float)*4;
+    }
+    assert(0);
+    return 0;
+  }
+
+
+  inline memory::dims getTensorDims(const std::shared_ptr<memory>& mem)
+  {
+    const mkldnn_memory_desc_t& desc = mem->get_desc().data;
+    return memory::dims(&desc.dims[0], &desc.dims[desc.ndims]);
+  }
+
+  inline memory::data_type getTensorType(const std::shared_ptr<memory>& mem)
+  {
+    const mkldnn_memory_desc_t& desc = mem->get_desc().data;
+    return memory::data_type(desc.data_type);
+  }
+
+  // Returns the number of values in a tensor
+  inline size_t getTensorSize(const memory::dims& dims)
+  {
+    size_t res = 1;
+    for (int i = 0; i < (int)dims.size(); ++i)
+      res *= dims[i];
+    return res;
+  }
+
+  inline memory::dims getMaxTensorDims(const std::vector<memory::dims>& dims)
+  {
+    memory::dims result;
+    size_t maxSize = 0;
+
+    for (const auto& d : dims)
+    {
+      const size_t size = getTensorSize(d);
+      if (size > maxSize)
+      {
+        result = d;
+        maxSize = size;
+      }
+    }
+
+    return result;
+  }
+
+  inline size_t getTensorSize(const std::shared_ptr<memory>& mem)
+  {
+    return getTensorSize(getTensorDims(mem));
+  }
+
+
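+  // Rounds dim up to the next multiple of K; the mask trick below assumes K is a
+  // power of two (8 or 16 here), e.g. getPadded<8>(3) == 8, getPadded<8>(16) == 16.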
+  template<int K>
+  inline int getPadded(int dim)
+  {
+    return (dim + (K-1)) & ~(K-1);
+  }
+
+  template<int K>
+  inline memory::dims getPadded_nchw(const memory::dims& dims)
+  {
+    assert(dims.size() == 4);
+    memory::dims padDims = dims;
+    padDims[1] = getPadded<K>(dims[1]); // pad C
+    return padDims;
+  }
+
+
+  template<int K>
+  struct BlockedFormat;
+
+  template<>
+  struct BlockedFormat<8>
+  {
+    static constexpr memory::format_tag nChwKc   = memory::format_tag::nChw8c;
+    static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw8i8o;
+  };
+
+  template<>
+  struct BlockedFormat<16>
+  {
+    static constexpr memory::format_tag nChwKc   = memory::format_tag::nChw16c;
+    static constexpr memory::format_tag OIhwKiKo = memory::format_tag::OIhw16i16o;
+  };
+
+} // namespace oidn

+ 238 - 0
thirdparty/oidn/core/device.cpp

@@ -0,0 +1,238 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "device.h"
+#include "autoencoder.h"
+
+namespace oidn {
+
+  thread_local Device::ErrorState Device::globalError;
+
+  Device::Device()
+  {
+    if (!mayiuse(sse41))
+      throw Exception(Error::UnsupportedHardware, "SSE4.1 support is required at minimum");
+  }
+
+  Device::~Device()
+  {
+    // -- GODOT start --
+    //observer.reset();
+    // -- GODOT end --
+  }
+
+  void Device::setError(Device* device, Error code, const std::string& message)
+  {
+    // Update the stored error only if the previous error was queried
+    if (device)
+    {
+      ErrorState& curError = device->error.get();
+
+      if (curError.code == Error::None)
+      {
+        curError.code = code;
+        curError.message = message;
+      }
+
+      // Print the error message in verbose mode
+      if (device->isVerbose())
+        std::cerr << "Error: " << message << std::endl;
+
+      // Call the error callback function
+      ErrorFunction errorFunc;
+      void* errorUserPtr;
+
+      {
+        std::lock_guard<std::mutex> lock(device->mutex);
+        errorFunc = device->errorFunc;
+        errorUserPtr = device->errorUserPtr;
+      }
+
+      if (errorFunc)
+        errorFunc(errorUserPtr, code, (code == Error::None) ? nullptr : message.c_str());
+    }
+    else
+    {
+      if (globalError.code == Error::None)
+      {
+        globalError.code = code;
+        globalError.message = message;
+      }
+    }
+  }
+
+  Error Device::getError(Device* device, const char** outMessage)
+  {
+    // Return and clear the stored error code, but keep the error message so pointers to it will
+    // remain valid until the next getError call
+    if (device)
+    {
+      ErrorState& curError = device->error.get();
+      const Error code = curError.code;
+      if (outMessage)
+        *outMessage = (code == Error::None) ? nullptr : curError.message.c_str();
+      curError.code = Error::None;
+      return code;
+    }
+    else
+    {
+      const Error code = globalError.code;
+      if (outMessage)
+        *outMessage = (code == Error::None) ? nullptr : globalError.message.c_str();
+      globalError.code = Error::None;
+      return code;
+    }
+  }
+
+  void Device::setErrorFunction(ErrorFunction func, void* userPtr)
+  {
+    errorFunc = func;
+    errorUserPtr = userPtr;
+  }
+
+  int Device::get1i(const std::string& name)
+  {
+    if (name == "numThreads")
+      return numThreads;
+    else if (name == "setAffinity")
+      return setAffinity;
+    else if (name == "verbose")
+      return verbose;
+    else if (name == "version")
+      return OIDN_VERSION;
+    else if (name == "versionMajor")
+      return OIDN_VERSION_MAJOR;
+    else if (name == "versionMinor")
+      return OIDN_VERSION_MINOR;
+    else if (name == "versionPatch")
+      return OIDN_VERSION_PATCH;
+    else
+      throw Exception(Error::InvalidArgument, "invalid parameter");
+  }
+
+  void Device::set1i(const std::string& name, int value)
+  {
+    if (name == "numThreads")
+      numThreads = value;
+    else if (name == "setAffinity")
+      setAffinity = value;
+    else if (name == "verbose")
+    {
+      verbose = value;
+      error.verbose = value;
+    }
+
+    dirty = true;
+  }
+
+  void Device::commit()
+  {
+    if (isCommitted())
+      throw Exception(Error::InvalidOperation, "device can be committed only once");
+
+    // -- GODOT start --
+    #if 0
+    // -- GODOT end --
+    // Get the optimal thread affinities
+    if (setAffinity)
+    {
+      affinity = std::make_shared<ThreadAffinity>(1, verbose); // one thread per core
+      if (affinity->getNumThreads() == 0)
+        affinity.reset();
+    }
+
+    // Create the task arena
+    const int maxNumThreads = affinity ? affinity->getNumThreads() : tbb::this_task_arena::max_concurrency();
+    numThreads = (numThreads > 0) ? min(numThreads, maxNumThreads) : maxNumThreads;
+    arena = std::make_shared<tbb::task_arena>(numThreads);
+
+    // Automatically set the thread affinities
+    if (affinity)
+      observer = std::make_shared<PinningObserver>(affinity, *arena);
+    // -- GODOT start --
+    #endif
+    numThreads = 1;
+    // -- GODOT end --
+    dirty = false;
+
+    if (isVerbose())
+      print();
+  }
+
+  void Device::checkCommitted()
+  {
+    if (dirty)
+      throw Exception(Error::InvalidOperation, "changes to the device are not committed");
+  }
+
+  Ref<Buffer> Device::newBuffer(size_t byteSize)
+  {
+    checkCommitted();
+    return makeRef<Buffer>(Ref<Device>(this), byteSize);
+  }
+
+  Ref<Buffer> Device::newBuffer(void* ptr, size_t byteSize)
+  {
+    checkCommitted();
+    return makeRef<Buffer>(Ref<Device>(this), ptr, byteSize);
+  }
+
+  Ref<Filter> Device::newFilter(const std::string& type)
+  {
+    checkCommitted();
+
+    if (isVerbose())
+      std::cout << "Filter: " << type << std::endl;
+
+    Ref<Filter> filter;
+
+// -- GODOT start --
+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
+#if 0
+// -- GODOT end --
+    if (type == "RT")
+      filter = makeRef<RTFilter>(Ref<Device>(this));
+// -- GODOT start --
+// Godot doesn't need Raytracing filters. Removing them saves space in the weights files.
+#endif
+    if (type == "RTLightmap")
+// -- GODOT end --
+      filter = makeRef<RTLightmapFilter>(Ref<Device>(this));
+    else
+      throw Exception(Error::InvalidArgument, "unknown filter type");
+
+    return filter;
+  }
+
+  void Device::print()
+  {
+    std::cout << std::endl;
+
+    std::cout << "Intel(R) Open Image Denoise " << OIDN_VERSION_STRING << std::endl;
+    std::cout << "  Compiler: " << getCompilerName() << std::endl;
+    std::cout << "  Build   : " << getBuildName() << std::endl;
+    std::cout << "  Platform: " << getPlatformName() << std::endl;
+
+// -- GODOT start --
+//    std::cout << "  Tasking :";
+//    std::cout << " TBB" << TBB_VERSION_MAJOR << "." << TBB_VERSION_MINOR;
+//    std::cout << " TBB_header_interface_" << TBB_INTERFACE_VERSION << " TBB_lib_interface_" << tbb::TBB_runtime_interface_version();
+//    std::cout << std::endl;
+// -- GODOT end --
+    std::cout << std::endl;
+  }
+
+} // namespace oidn

+ 102 - 0
thirdparty/oidn/core/device.h

@@ -0,0 +1,102 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+
+namespace oidn {
+
+  class Buffer;
+  class Filter;
+
+  class Device : public RefCount, public Verbose
+  {
+  private:
+    // Thread-safety
+    std::mutex mutex;
+
+    // Error handling
+    struct ErrorState
+    {
+      Error code = Error::None;
+      std::string message;
+    };
+
+    static thread_local ErrorState globalError;
+    ThreadLocal<ErrorState> error;
+    ErrorFunction errorFunc = nullptr;
+    void* errorUserPtr = nullptr;
+
+// -- GODOT start --
+//    // Tasking
+//    std::shared_ptr<tbb::task_arena> arena;
+//    std::shared_ptr<PinningObserver> observer;
+//    std::shared_ptr<ThreadAffinity> affinity;
+// -- GODOT end --
+
+    // Parameters
+    int numThreads = 0; // autodetect by default
+    bool setAffinity = true;
+
+    bool dirty = true;
+
+  public:
+    Device();
+    ~Device();
+
+    static void setError(Device* device, Error code, const std::string& message);
+    static Error getError(Device* device, const char** outMessage);
+
+    void setErrorFunction(ErrorFunction func, void* userPtr);
+
+    int get1i(const std::string& name);
+    void set1i(const std::string& name, int value);
+
+    void commit();
+
+// -- GODOT start --
+//    template<typename F>
+//    void executeTask(F& f)
+//    {
+//      arena->execute(f);
+//    }
+
+//    template<typename F>
+//    void executeTask(const F& f)
+//    {
+//      arena->execute(f);
+//    }
+// -- GODOT end --
+
+    Ref<Buffer> newBuffer(size_t byteSize);
+    Ref<Buffer> newBuffer(void* ptr, size_t byteSize);
+    Ref<Filter> newFilter(const std::string& type);
+
+    __forceinline Device* getDevice() { return this; }
+    __forceinline std::mutex& getMutex() { return mutex; }
+
+  private:
+// -- GODOT start --
+  //bool isCommitted() const { return bool(arena); }
+  bool isCommitted() const { return false; }
+// -- GODOT end --
+    void checkCommitted();
+
+    void print();
+  };
+
+} // namespace oidn

+ 27 - 0
thirdparty/oidn/core/filter.cpp

@@ -0,0 +1,27 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "filter.h"
+
+namespace oidn {
+
+  void Filter::setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr)
+  {
+    progressFunc = func;
+    progressUserPtr = userPtr;
+  }
+
+} // namespace oidn

+ 52 - 0
thirdparty/oidn/core/filter.h

@@ -0,0 +1,52 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+#include "device.h"
+#include "image.h"
+
+namespace oidn {
+
+  class Filter : public RefCount
+  {
+  protected:
+    Ref<Device> device;
+
+    ProgressMonitorFunction progressFunc = nullptr;
+    void* progressUserPtr = nullptr;
+
+    bool dirty = true;
+
+  public:
+    explicit Filter(const Ref<Device>& device) : device(device) {}
+
+    virtual void setImage(const std::string& name, const Image& data) = 0;
+    virtual void set1i(const std::string& name, int value) = 0;
+    virtual int get1i(const std::string& name) = 0;
+    virtual void set1f(const std::string& name, float value) = 0;
+    virtual float get1f(const std::string& name) = 0;
+
+    void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr);
+
+    virtual void commit() = 0;
+    virtual void execute() = 0;
+
+    Device* getDevice() { return device.get(); }
+  };
+
+} // namespace oidn

+ 111 - 0
thirdparty/oidn/core/image.h

@@ -0,0 +1,111 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+#include "buffer.h"
+
+namespace oidn {
+
+  struct Image
+  {
+    static constexpr int maxSize = 65536;
+
+    char* ptr;              // pointer to the first pixel
+    int width;              // width in number of pixels
+    int height;             // height in number of pixels
+    size_t bytePixelStride; // pixel stride in number of *bytes*
+    size_t rowStride;       // row stride in number of *pixel strides*
+    Format format;          // pixel format
+    Ref<Buffer> buffer;     // buffer containing the image data
+
+    Image() : ptr(nullptr), width(0), height(0), bytePixelStride(0), rowStride(0), format(Format::Undefined) {}
+
+    Image(void* ptr, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
+    {
+      if (ptr == nullptr)
+        throw Exception(Error::InvalidArgument, "buffer pointer null");
+
+      init((char*)ptr + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
+    }
+
+    Image(const Ref<Buffer>& buffer, Format format, int width, int height, size_t byteOffset, size_t inBytePixelStride, size_t inByteRowStride)
+    {
+      init(buffer->data() + byteOffset, format, width, height, inBytePixelStride, inByteRowStride);
+
+      if (byteOffset + height * rowStride * bytePixelStride > buffer->size())
+        throw Exception(Error::InvalidArgument, "buffer region out of range");
+    }
+
+    void init(char* ptr, Format format, int width, int height, size_t inBytePixelStride, size_t inByteRowStride)
+    {
+      assert(width >= 0);
+      assert(height >= 0);
+      if (width > maxSize || height > maxSize)
+        throw Exception(Error::InvalidArgument, "image size too large");
+
+      this->ptr = ptr;
+      this->width = width;
+      this->height = height;
+
+      const size_t pixelSize = getFormatBytes(format);
+      if (inBytePixelStride != 0)
+      {
+        if (inBytePixelStride < pixelSize)
+          throw Exception(Error::InvalidArgument, "pixel stride smaller than pixel size");
+
+        this->bytePixelStride = inBytePixelStride;
+      }
+      else
+      {
+        this->bytePixelStride = pixelSize;
+      }
+
+      if (inByteRowStride != 0)
+      {
+        if (inByteRowStride < width * this->bytePixelStride)
+          throw Exception(Error::InvalidArgument, "row stride smaller than width * pixel stride");
+        if (inByteRowStride % this->bytePixelStride != 0)
+          throw Exception(Error::InvalidArgument, "row stride not integer multiple of pixel stride");
+
+        this->rowStride = inByteRowStride / this->bytePixelStride;
+      }
+      else
+      {
+        this->rowStride = width;
+      }
+
+      this->format = format;
+    }
+
+    __forceinline char* get(int y, int x)
+    {
+      return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
+    }
+
+    __forceinline const char* get(int y, int x) const
+    {
+      return ptr + ((size_t(y) * rowStride + size_t(x)) * bytePixelStride);
+    }
+
+    operator bool() const
+    {
+      return ptr != nullptr;
+    }
+  };
+
+} // namespace oidn

+ 232 - 0
thirdparty/oidn/core/input_reorder.h

@@ -0,0 +1,232 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "node.h"
+#include "image.h"
+
+namespace oidn {
+
+  // Input reorder node
+  template<int K, class TransferFunction>
+  class InputReorderNode : public Node
+  {
+  private:
+    // Source
+    Image color;
+    Image albedo;
+    Image normal;
+
+    // Destination
+    std::shared_ptr<memory> dst;
+    float* dstPtr;
+    int C2;
+    int H2;
+    int W2;
+
+    // Tile
+    int h1Begin;
+    int w1Begin;
+    int h2Begin;
+    int w2Begin;
+    int H;
+    int W;
+
+    std::shared_ptr<TransferFunction> transferFunc;
+
+  public:
+    InputReorderNode(const Image& color,
+                     const Image& albedo,
+                     const Image& normal,
+                     const std::shared_ptr<memory>& dst,
+                     const std::shared_ptr<TransferFunction>& transferFunc)
+      : color(color), albedo(albedo), normal(normal),
+        dst(dst),
+        h1Begin(0), w1Begin(0),
+        H(color.height), W(color.width),
+        transferFunc(transferFunc)
+    {
+      const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
+      assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
+      assert(dstDesc.ndims == 4);
+      assert(dstDesc.data_type == memory::data_type::f32);
+      assert(dstDesc.dims[0] == 1);
+      //assert(dstDesc.dims[1] >= getPadded<K>(C1));
+
+      dstPtr = (float*)dst->get_data_handle();
+      C2 = dstDesc.dims[1];
+      H2 = dstDesc.dims[2];
+      W2 = dstDesc.dims[3];
+    }
+
+    void setTile(int h1, int w1, int h2, int w2, int H, int W) override
+    {
+      h1Begin = h1;
+      w1Begin = w1;
+      h2Begin = h2;
+      w2Begin = w2;
+      this->H = H;
+      this->W = W;
+    }
+
+    void execute(stream& sm) override
+    {
+      assert(H + h1Begin <= color.height);
+      assert(W + w1Begin <= color.width);
+      assert(H + h2Begin <= H2);
+      assert(W + w2Begin <= W2);
+
+      parallel_nd(H2, [&](int h2)
+      {
+        const int h = h2 - h2Begin;
+
+        if (h >= 0 && h < H)
+        {
+          const int h1 = h + h1Begin;
+
+          // Zero pad
+          for (int w2 = 0; w2 < w2Begin; ++w2)
+          {
+            int c = 0;
+            while (c < C2)
+              store(h2, w2, c, 0.f);
+          }
+
+          // Reorder
+          for (int w = 0; w < W; ++w)
+          {
+            const int w1 = w + w1Begin;
+            const int w2 = w + w2Begin;
+
+            int c = 0;
+            storeColor(h2, w2, c, (float*)color.get(h1, w1));
+            if (albedo)
+              storeAlbedo(h2, w2, c, (float*)albedo.get(h1, w1));
+            if (normal)
+              storeNormal(h2, w2, c, (float*)normal.get(h1, w1));
+            while (c < C2)
+              store(h2, w2, c, 0.f);
+          }
+
+          // Zero pad
+          for (int w2 = W + w2Begin; w2 < W2; ++w2)
+          {
+            int c = 0;
+            while (c < C2)
+              store(h2, w2, c, 0.f);
+          }
+        }
+        else
+        {
+          // Zero pad
+          for (int w2 = 0; w2 < W2; ++w2)
+          {
+            int c = 0;
+            while (c < C2)
+              store(h2, w2, c, 0.f);
+          }
+        }
+      });
+    }
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+
+  private:
+    // Stores a single value
+    __forceinline void store(int h, int w, int& c, float value)
+    {
+      // Destination is in nChwKc format
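+      // Channels are grouped into blocks of K: c/K selects an H2*W2*K slab, (h, w)
+      // selects the pixel within the slab, and c%K the channel within the block.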
+      float* dst_c = dstPtr + (H2*W2*K*(c/K)) + h*W2*K + w*K + (c%K);
+      *dst_c = value;
+      c++;
+    }
+
+    // Stores a color
+    __forceinline void storeColor(int h, int w, int& c, const float* values)
+    {
+      #pragma unroll
+      for (int i = 0; i < 3; ++i)
+      {
+        // Load the value
+        float x = values[i];
+
+        // Sanitize the value
+        x = maxSafe(x, 0.f);
+
+        // Apply the transfer function
+        x = transferFunc->forward(x);
+
+        // Store the value
+        store(h, w, c, x);
+      }
+    }
+
+    // Stores an albedo
+    __forceinline void storeAlbedo(int h, int w, int& c, const float* values)
+    {
+      #pragma unroll
+      for (int i = 0; i < 3; ++i)
+      {
+        // Load the value
+        float x = values[i];
+
+        // Sanitize the value
+        x = clampSafe(x, 0.f, 1.f);
+
+        // Store the value
+        store(h, w, c, x);
+      }
+    }
+
+    // Stores a normal
+    __forceinline void storeNormal(int h, int w, int& c, const float* values)
+    {
+      // Load the normal
+      float x = values[0];
+      float y = values[1];
+      float z = values[2];
+
+      // Compute the length of the normal
+      const float lengthSqr = sqr(x) + sqr(y) + sqr(z);
+
+      // Normalize the normal and transform it to [0..1]
+      if (isfinite(lengthSqr))
+      {
+        const float invLength = (lengthSqr > minVectorLengthSqr) ? rsqrt(lengthSqr) : 1.f;
+
+        const float scale  = invLength * 0.5f;
+        const float offset = 0.5f;
+
+        x = x * scale + offset;
+        y = y * scale + offset;
+        z = z * scale + offset;
+      }
+      else
+      {
+        x = 0.f;
+        y = 0.f;
+        z = 0.f;
+      }
+
+      // Store the normal
+      store(h, w, c, x);
+      store(h, w, c, y);
+      store(h, w, c, z);
+    }
+  };
+
+} // namespace oidn

+ 78 - 0
thirdparty/oidn/core/math.h

@@ -0,0 +1,78 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common/platform.h"
+
+namespace oidn {
+
+  constexpr float minVectorLength    = 1e-10f;
+  constexpr float minVectorLengthSqr = minVectorLength * minVectorLength;
+
+  using std::log;
+  using std::log2;
+  using std::exp;
+  using std::exp2;
+  using std::pow;
+  using std::isfinite;
+  using std::isnan;
+
+  __forceinline float sqr(float x)
+  {
+    return x * x;
+  }
+
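+  // Fast approximate reciprocal: SSE rcpss estimate refined with one Newton-Raphson step.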
+  __forceinline float rcp(float x)
+  {
+    __m128 r = _mm_rcp_ss(_mm_set_ss(x));
+    return _mm_cvtss_f32(_mm_sub_ss(_mm_add_ss(r, r), _mm_mul_ss(_mm_mul_ss(r, r), _mm_set_ss(x))));
+  }
+
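+  // Fast approximate reciprocal square root: SSE rsqrtss estimate refined with one
+  // Newton-Raphson step.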
+  __forceinline float rsqrt(float x)
+  {
+    __m128 r = _mm_rsqrt_ss(_mm_set_ss(x));
+    return _mm_cvtss_f32(_mm_add_ss(_mm_mul_ss(_mm_set_ss(1.5f), r),
+             _mm_mul_ss(_mm_mul_ss(_mm_mul_ss(_mm_set_ss(x), _mm_set_ss(-0.5f)), r), _mm_mul_ss(r, r))));
+  }
+
+  __forceinline float maxSafe(float value, float minValue)
+  {
+    return isfinite(value) ? max(value, minValue) : minValue;
+  }
+
+  __forceinline float clampSafe(float value, float minValue, float maxValue)
+  {
+    return isfinite(value) ? clamp(value, minValue, maxValue) : minValue;
+  }
+
+  // Returns ceil(a / b) for non-negative integers
+  template<class Int>
+  __forceinline constexpr Int ceilDiv(Int a, Int b)
+  {
+    //assert(a >= 0);
+    //assert(b > 0);
+    return (a + b - 1) / b;
+  }
+
+  // Returns a rounded up to multiple of b
+  template<class Int>
+  __forceinline constexpr Int roundUp(Int a, Int b)
+  {
+    return ceilDiv(a, b) * b;
+  }
+
+} // namespace oidn

+ 436 - 0
thirdparty/oidn/core/network.cpp

@@ -0,0 +1,436 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "upsample.h"
+#include "weights_reorder.h"
+#include "network.h"
+// -- GODOT start -- 
+#include <cstring>
+// -- GODOT end --
+
+namespace oidn {
+
+  template<int K>
+  Network<K>::Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap)
+    : device(device),
+      eng(engine::cpu, 0),
+      sm(eng),
+      weightMap(weightMap)
+  {
+  }
+
+  template<int K>
+  void Network<K>::execute(const Progress& progress, int taskIndex)
+  {
+    if (progress.func)
+    {
+      const double value = double(taskIndex) / double(progress.taskCount);
+      if (!progress.func(progress.userPtr, value))
+        throw Exception(Error::Cancelled, "execution was cancelled");
+    }
+
+    for (size_t i = 0; i < nodes.size(); ++i)
+    {
+      nodes[i]->execute(sm);
+
+      if (progress.func)
+      {
+        const double value = (double(taskIndex) + double(i+1) / double(nodes.size())) / double(progress.taskCount);
+        if (!progress.func(progress.userPtr, value))
+          throw Exception(Error::Cancelled, "execution was cancelled");
+      }
+    }
+  }
+
+  template<int K>
+  std::shared_ptr<memory> Network<K>::allocTensor(const memory::dims& dims,
+                                                  memory::format_tag format,
+                                                  void* data)
+  {
+    if (format == memory::format_tag::any)
+    {
+      if (dims.size() == 4)
+        format = BlockedFormat<K>::nChwKc;
+      else if (dims.size() == 1)
+        format = memory::format_tag::x;
+      else
+        assert(0);
+    }
+    memory::desc desc(dims, memory::data_type::f32, format);
+    if (data == nullptr)
+    {
+      const size_t bytes = getTensorSize(dims) * sizeof(float);
+      if (format == BlockedFormat<K>::nChwKc)
+        activationAllocBytes += bytes;
+      totalAllocBytes += bytes;
+
+      return std::make_shared<memory>(desc, eng);
+    }
+    else
+    {
+      return std::make_shared<memory>(desc, eng, data);
+    }
+  }
+
+  template<int K>
+  std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
+                                                 const std::shared_ptr<memory>& src,
+                                                 size_t srcOffset,
+                                                 memory::format_tag format)
+  {
+    const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+    MAYBE_UNUSED(srcDesc);
+    assert(srcDesc.data_type == memory::data_type::f32);
+    assert(getTensorSize(src) >= srcOffset + getTensorSize(dims));
+
+    if (format == memory::format_tag::any)
+    {
+      if (dims.size() == 4)
+        format = BlockedFormat<K>::nChwKc;
+      else if (dims.size() == 1)
+        format = memory::format_tag::x;
+      else
+        assert(0);
+    }
+    memory::desc desc(dims, memory::data_type::f32, format);
+    float* srcPtr = (float*)src->get_data_handle() + srcOffset;
+    return std::make_shared<memory>(desc, eng, srcPtr);
+  }
+
+  template<int K>
+  std::shared_ptr<memory> Network<K>::castTensor(const memory::dims& dims,
+                                                 const std::shared_ptr<memory>& src,
+                                                 const memory::dims& srcOffset)
+  {
+    return castTensor(dims, src, getTensorSize(srcOffset));
+  }
+
+  template<int K>
+  void Network<K>::zeroTensor(const std::shared_ptr<memory>& dst)
+  {
+    assert(getTensorType(dst) == memory::data_type::f32);
+    memset(dst->get_data_handle(), 0, getTensorSize(dst)*sizeof(float));
+  }
+
+  template<int K>
+  memory::dims Network<K>::getInputReorderDims(const memory::dims& srcDims, int alignment)
+  {
+    memory::dims dstDims = srcDims;
+    dstDims[1] = getPadded<K>(srcDims[1]); // round up C
+    dstDims[2] = roundUp(srcDims[2], memory::dim(alignment)); // round up H
+    dstDims[3] = roundUp(srcDims[3], memory::dim(alignment)); // round up W
+    return dstDims;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addInputReorder(const Image& color,
+                                                    const Image& albedo,
+                                                    const Image& normal,
+                                                    const std::shared_ptr<TransferFunction>& transferFunc,
+                                                    int alignment,
+                                                    const std::shared_ptr<memory>& userDst)
+  {
+    assert(color);
+    int inputC = 3;
+    if (albedo) inputC += 3;
+    if (normal) inputC += 3;
+
+    memory::dims srcDims = {1, inputC, color.height, color.width};
+    memory::dims dstDims = getInputReorderDims(srcDims, alignment);
+
+    // Allocate padded memory
+    auto dst = userDst;
+    if (!dst)
+      dst = allocTensor(dstDims);
+
+    // Push node
+    std::shared_ptr<Node> node;
+
+    if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
+      node = std::make_shared<InputReorderNode<K, LinearTransferFunction>>(color, albedo, normal, dst, tf);
+    else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
+      node = std::make_shared<InputReorderNode<K, GammaTransferFunction>>(color, albedo, normal, dst, tf);
+    else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
+      node = std::make_shared<InputReorderNode<K, LogTransferFunction>>(color, albedo, normal, dst, tf);
+    else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
+      node = std::make_shared<InputReorderNode<K, PQXTransferFunction>>(color, albedo, normal, dst, tf);
+    else
+      assert(0);
+
+    nodes.push_back(node);
+    return node;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addOutputReorder(const std::shared_ptr<memory>& src,
+                                                     const std::shared_ptr<TransferFunction>& transferFunc,
+                                                     const Image& output)
+  {
+    memory::dims srcDims = getTensorDims(src);
+    assert(srcDims[1] == K);
+
+    // Push node
+    std::shared_ptr<Node> node;
+
+    if (auto tf = std::dynamic_pointer_cast<LinearTransferFunction>(transferFunc))
+      node = std::make_shared<OutputReorderNode<K, LinearTransferFunction>>(src, output, tf);
+    else if (auto tf = std::dynamic_pointer_cast<GammaTransferFunction>(transferFunc))
+      node = std::make_shared<OutputReorderNode<K, GammaTransferFunction>>(src, output, tf);
+    else if (auto tf = std::dynamic_pointer_cast<LogTransferFunction>(transferFunc))
+      node = std::make_shared<OutputReorderNode<K, LogTransferFunction>>(src, output, tf);
+    else if (auto tf = std::dynamic_pointer_cast<PQXTransferFunction>(transferFunc))
+      node = std::make_shared<OutputReorderNode<K, PQXTransferFunction>>(src, output, tf);
+    else
+      assert(0);
+
+    nodes.push_back(node);
+    return node;
+  }
+
+  template<int K>
+  memory::dims Network<K>::getConvDims(const std::string& name, const memory::dims& srcDims)
+  {
+    auto b = weightMap[name + "/b"];
+    memory::dims dstDims = srcDims;
+    dstDims[1] = getPadded<K>(b.dims[0]); // dstDims[C] = getPadded(OC)
+    return dstDims;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addConv(const std::string& name,
+                                            const std::shared_ptr<memory>& src,
+                                            const std::shared_ptr<memory>& userDst,
+                                            bool relu)
+  {
+    const memory::dims strides = {1, 1};
+    const memory::dims padding = {1, 1};
+
+    memory::dims srcDims = getTensorDims(src);
+
+    // Get the weights
+    const auto& W = weightMap[name + "/W"];
+    if (W.ndims() != 4 || W.format != "oihw")
+      throw Exception(Error::InvalidOperation, "invalid convolution weights");
+    memory::dims weightsDims = W.dims;
+    auto userWeights = allocTensor(weightsDims, memory::format_tag::oihw, W.data);
+
+    // Pad the weights
+    memory::dims weightsPadDims = weightsDims;
+    weightsPadDims[1] = getPadded<K>(weightsDims[1]); // IC
+    weightsPadDims[0] = getPadded<K>(weightsDims[0]); // OC
+    assert(srcDims[1] == weightsPadDims[1]); // srcDims[C] == weightsPadDims[IC]
+    auto weightsPad = allocTensor(weightsPadDims, memory::format_tag::oihw);
+    WeightsReorderNode<K>(userWeights, weightsPad).execute(sm);
+
+    // Get the biases
+    const auto& b = weightMap[name + "/b"];
+    if (b.ndims() != 1)
+      throw Exception(Error::InvalidOperation, "invalid convolution biases");
+    memory::dims biasDims = b.dims;
+
+    // Copy/pad the biases
+    memory::dims biasPadDims = {getPadded<K>(biasDims[0])};
+    auto bias = allocTensor(biasPadDims);
+    if (biasDims[0] != biasPadDims[0])
+      memset(bias->get_data_handle(), 0, biasPadDims[0]*sizeof(float));
+    memcpy(bias->get_data_handle(), b.data, biasDims[0]*sizeof(float));
+
+    // Allocate memory for destination
+    memory::dims dstDims = srcDims;
+    dstDims[1] = weightsPadDims[0]; // dstDims[C] = weightsPadDims[OC]
+
+    std::shared_ptr<memory> dst;
+    if (!userDst)
+      dst = allocTensor(dstDims);
+    else if (getTensorDims(userDst) == dstDims)
+      dst = userDst;
+    else
+      dst = castTensor(dstDims, userDst);
+
+    // Create a convolution
+    // Let the convolution primitive choose the weights format
+    auto weightsDesc = memory::desc({ weightsPadDims }, memory::data_type::f32, memory::format_tag::any);
+
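+    // Winograd convolution is selected for the K=16 path (used on AVX-512-capable
+    // CPUs); the K=8 path uses direct convolution.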
+    auto convAlgo = (K == 16) ? convolution_winograd : convolution_direct;
+    auto convDesc = convolution_forward::desc(
+      prop_kind::forward_inference, convAlgo,
+      src->get_desc(),
+      weightsDesc,
+      bias->get_desc(),
+      dst->get_desc(),
+      strides, padding, padding, padding_kind::zero);
+
+    // Incorporate relu
+    mkldnn::primitive_attr convAttr;
+    if (relu)
+    {
+      mkldnn::post_ops ops;
+      ops.append_eltwise(
+        1.f,   // scale factor, not used
+        algorithm::eltwise_relu,
+        0.f,   // alpha: negative slope, 0 = standard ReLU (max with 0)
+        0.f    // beta: unused for ReLU
+      );
+      convAttr.set_post_ops(ops);
+    }
+    convAttr.set_scratchpad_mode(scratchpad_mode_user);
+
+    auto convPrimDesc = convolution_forward::primitive_desc(convDesc, convAttr, eng);
+
+    // Reorder the weights to the final format, if necessary
+    auto weights = weightsPad;
+    if (convPrimDesc.weights_desc() != weightsPad->get_desc())
+    {
+      weights = std::make_shared<memory>(convPrimDesc.weights_desc(), eng);
+      ReorderNode(weightsPad, weights).execute(sm);
+    }
+
+    // Create convolution node and add it to the net
+    auto node = std::make_shared<ConvNode>(convPrimDesc, src, weights, bias, dst);
+    nodes.push_back(node);
+    return node;
+  }
+
+  template<int K>
+  memory::dims Network<K>::getPoolDims(const memory::dims& srcDims)
+  {
+    memory::dims dstDims = srcDims;
+    dstDims[2] /= 2; // H/2
+    dstDims[3] /= 2; // W/2
+    return dstDims;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addPool(const std::shared_ptr<memory>& src,
+                                            const std::shared_ptr<memory>& userDst)
+  {
+    const memory::dims kernel  = {2, 2};
+    const memory::dims strides = {2, 2};
+    const memory::dims padding = {0, 0};
+
+    memory::dims srcDims = getTensorDims(src);
+    memory::dims dstDims = getPoolDims(srcDims);
+
+    std::shared_ptr<memory> dst;
+    if (!userDst)
+      dst = allocTensor(dstDims);
+    else if (getTensorDims(userDst) == dstDims)
+      dst = userDst;
+    else
+      dst = castTensor(dstDims, userDst);
+
+    auto poolDesc = pooling_forward::desc(
+      prop_kind::forward_inference, pooling_max,
+      src->get_desc(),
+      dst->get_desc(),
+      strides, kernel, padding, padding, padding_kind::zero);
+
+    mkldnn::primitive_attr poolAttr;
+    poolAttr.set_scratchpad_mode(scratchpad_mode_user);
+
+    auto poolPrimDesc = pooling_forward::primitive_desc(poolDesc, poolAttr, eng);
+
+    auto node = std::make_shared<PoolNode>(poolPrimDesc, src, dst);
+    nodes.push_back(node);
+    return node;
+  }
+
+  template<int K>
+  memory::dims Network<K>::getUpsampleDims(const memory::dims& srcDims)
+  {
+    memory::dims dstDims = srcDims;
+    dstDims[2] *= 2; // H*2
+    dstDims[3] *= 2; // W*2
+    return dstDims;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addUpsample(const std::shared_ptr<memory>& src,
+                                                const std::shared_ptr<memory>& userDst)
+  {
+    memory::dims srcDims = getTensorDims(src);
+    memory::dims dstDims = getUpsampleDims(srcDims);
+
+    std::shared_ptr<memory> dst;
+    if (!userDst)
+      dst = allocTensor(dstDims);
+    else if (getTensorDims(userDst) == dstDims)
+      dst = userDst;
+    else
+      dst = castTensor(dstDims, userDst);
+
+    // Create upsampling node and add it to net
+    auto node = std::make_shared<UpsampleNode<K>>(src, dst);
+    nodes.push_back(node);
+    return node;
+  }
+
+  template<int K>
+  memory::dims Network<K>::getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims)
+  {
+    assert(src1Dims[0] == src2Dims[0]); // N
+    assert(src1Dims[2] == src2Dims[2]); // H
+    assert(src1Dims[3] == src2Dims[3]); // W
+
+    memory::dims dstDims = src1Dims;
+    dstDims[1] += src2Dims[1]; // C
+    return dstDims;
+  }
+
+  template<int K>
+  std::shared_ptr<Node> Network<K>::addAutoexposure(const Image& color,
+                                                    const std::shared_ptr<HDRTransferFunction>& transferFunc)
+  {
+    auto node = std::make_shared<AutoexposureNode>(color, transferFunc);
+    nodes.push_back(node);
+    return node;
+  }
+
+  template <int K>
+  void Network<K>::finalize()
+  {
+    // Compute the size of the scratchpad
+    size_t scratchpadSize = 0;
+    for (const auto& node : nodes)
+      scratchpadSize = max(scratchpadSize, node->getScratchpadSize());
+
+    // Allocate the scratchpad
+    memory::dims scratchpadDims = { memory::dim(scratchpadSize) };
+    memory::desc scratchpadDesc(scratchpadDims, memory::data_type::u8, memory::format_tag::x);
+    auto scratchpad = std::make_shared<memory>(scratchpadDesc, eng);
+    activationAllocBytes += scratchpadSize;
+    totalAllocBytes += scratchpadSize;
+
+    // Set the scratchpad for the nodes
+    for (auto& node : nodes)
+      node->setScratchpad(scratchpad);
+
+    // Free the weights
+    weightMap.clear();
+
+    // Print statistics
+    if (device->isVerbose(2))
+    {
+      std::cout << "Activation bytes: " << activationAllocBytes << std::endl;
+      std::cout << "Scratchpad bytes: " << scratchpadSize << std::endl;
+      std::cout << "Total bytes     : " << totalAllocBytes << std::endl;
+    }
+  }
+
+  template class Network<8>;
+  template class Network<16>;
+
+} // namespace oidn
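The padding logic in addConv above rounds every channel count up to the block size K before creating MKL-DNN memory. A minimal standalone sketch of that rounding, where getPaddedSketch is a hypothetical helper that only mirrors what getPadded<K> is expected to do:

    #include <cassert>

    // Hypothetical helper mirroring oidn's getPadded<K>(): round up to a multiple of K.
    template<int K>
    int getPaddedSketch(int dim)
    {
      return (dim + (K - 1)) & ~(K - 1);
    }

    int main()
    {
      assert(getPaddedSketch<8>(3)   == 8);   // 3 input channels -> one 8-wide block
      assert(getPaddedSketch<8>(9)   == 16);
      assert(getPaddedSketch<16>(3)  == 16);  // K = 16 build pads to 16
      assert(getPaddedSketch<16>(32) == 32);  // already aligned, unchanged
      return 0;
    }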

+ 112 - 0
thirdparty/oidn/core/network.h

@@ -0,0 +1,112 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "common/tensor.h"
+#include "image.h"
+#include "node.h"
+#include "input_reorder.h"
+#include "output_reorder.h"
+#include "transfer_function.h"
+
+#pragma once
+
+namespace oidn {
+
+  // Progress state
+  struct Progress
+  {
+    ProgressMonitorFunction func;
+    void* userPtr;
+    int taskCount;
+  };
+
+  class Executable
+  {
+  public:
+    virtual ~Executable() {}
+    virtual void execute(const Progress& progress, int taskIndex) = 0;
+  };
+
+  template<int K>
+  class Network : public Executable
+  {
+  public:
+    Network(const Ref<Device>& device, const std::map<std::string, Tensor>& weightMap);
+
+    void execute(const Progress& progress, int taskIndex) override;
+
+    std::shared_ptr<memory> allocTensor(const memory::dims& dims,
+                                        memory::format_tag format = memory::format_tag::any,
+                                        void* data = nullptr);
+
+    std::shared_ptr<memory> castTensor(const memory::dims& dims,
+                                       const std::shared_ptr<memory>& src,
+                                       size_t srcOffset = 0,
+                                       memory::format_tag format = memory::format_tag::any);
+
+    std::shared_ptr<memory> castTensor(const memory::dims& dims,
+                                       const std::shared_ptr<memory>& src,
+                                       const memory::dims& srcOffset);
+
+    void zeroTensor(const std::shared_ptr<memory>& dst);
+
+    memory::dims getInputReorderDims(const memory::dims& srcDims, int alignment);
+
+    std::shared_ptr<Node> addInputReorder(const Image& color,
+                                          const Image& albedo,
+                                          const Image& normal,
+                                          const std::shared_ptr<TransferFunction>& transferFunc,
+                                          int alignment,
+                                          const std::shared_ptr<memory>& userDst = nullptr);
+
+    std::shared_ptr<Node> addOutputReorder(const std::shared_ptr<memory>& src,
+                                           const std::shared_ptr<TransferFunction>& transferFunc,
+                                           const Image& output);
+
+    memory::dims getConvDims(const std::string& name, const memory::dims& srcDims);
+    std::shared_ptr<Node> addConv(const std::string& name,
+                                  const std::shared_ptr<memory>& src,
+                                  const std::shared_ptr<memory>& userDst = nullptr,
+                                  bool relu = true);
+
+    memory::dims getPoolDims(const memory::dims& srcDims);
+    std::shared_ptr<Node> addPool(const std::shared_ptr<memory>& src,
+                                  const std::shared_ptr<memory>& userDst = nullptr);
+
+    memory::dims getUpsampleDims(const memory::dims& srcDims);
+    std::shared_ptr<Node> addUpsample(const std::shared_ptr<memory>& src,
+                                      const std::shared_ptr<memory>& userDst = nullptr);
+
+    memory::dims getConcatDims(const memory::dims& src1Dims, const memory::dims& src2Dims);
+
+    std::shared_ptr<Node> addAutoexposure(const Image& color,
+                                          const std::shared_ptr<HDRTransferFunction>& transferFunc);
+
+    void finalize();
+
+  private:
+    Ref<Device> device;
+    engine eng;
+    stream sm;
+    std::vector<std::shared_ptr<Node>> nodes;
+    std::map<std::string, Tensor> weightMap;
+
+    // Memory allocation statistics
+    size_t activationAllocBytes = 0; // number of allocated activation bytes
+    size_t totalAllocBytes      = 0; // total number of allocated bytes
+  };
+
+} // namespace oidn
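The builder methods declared above touch tensor shapes in only a few simple ways: pooling halves the spatial dimensions, upsampling doubles them back, and concatenation sums channel counts. A standalone sketch of that shape bookkeeping, using a plain std::vector as a stand-in for memory::dims:

    #include <cassert>
    #include <vector>

    using Dims = std::vector<long>; // stand-in for mkldnn's memory::dims (NCHW order)

    Dims poolDims(Dims d)     { d[2] /= 2; d[3] /= 2; return d; }
    Dims upsampleDims(Dims d) { d[2] *= 2; d[3] *= 2; return d; }
    Dims concatDims(const Dims& a, const Dims& b)
    {
      Dims d = a; d[1] += b[1]; return d;
    }

    int main()
    {
      const Dims src = {1, 32, 128, 256};              // N, C, H, W
      assert(poolDims(src) == Dims({1, 32, 64, 128})); // pooling halves H and W
      assert(upsampleDims(poolDims(src)) == src);      // upsampling restores them
      assert(concatDims(src, {1, 64, 128, 256}) == Dims({1, 96, 128, 256})); // channels add up
      return 0;
    }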

+ 142 - 0
thirdparty/oidn/core/node.h

@@ -0,0 +1,142 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "common.h"
+#include <vector>
+
+namespace oidn {
+
+  class Node
+  {
+  public:
+    virtual ~Node() = default;
+
+    virtual void execute(stream& sm) = 0;
+
+    virtual std::shared_ptr<memory> getDst() const { return nullptr; }
+
+    virtual size_t getScratchpadSize() const { return 0; }
+    virtual void setScratchpad(const std::shared_ptr<memory>& mem) {}
+
+    virtual void setTile(int h1, int w1, int h2, int w2, int H, int W)
+    {
+      assert(0); // not supported
+    }
+  };
+
+  // Node wrapping an MKL-DNN primitive
+  class MklNode : public Node
+  {
+  private:
+    primitive prim;
+    std::unordered_map<int, memory> args;
+    std::shared_ptr<memory> scratchpad;
+
+  public:
+    MklNode(const primitive& prim, const std::unordered_map<int, memory>& args)
+      : prim(prim),
+        args(args)
+    {}
+
+    size_t getScratchpadSize() const override
+    {
+      const auto primDesc = prim.get_primitive_desc();
+      const mkldnn_memory_desc_t* scratchpadDesc = mkldnn_primitive_desc_query_md(primDesc, mkldnn_query_scratchpad_md, 0);
+      if (scratchpadDesc == nullptr)
+        return 0;
+      return mkldnn_memory_desc_get_size(scratchpadDesc);
+    }
+
+    void setScratchpad(const std::shared_ptr<memory>& mem) override
+    {
+      scratchpad = mem;
+      args.insert(std::make_pair(MKLDNN_ARG_SCRATCHPAD, *scratchpad));
+    }
+
+    void execute(stream& sm) override
+    {
+      prim.execute(sm, args);
+    }
+  };
+
+  // Convolution node
+  class ConvNode : public MklNode
+  {
+  private:
+    std::shared_ptr<memory> src;
+    std::shared_ptr<memory> weights;
+    std::shared_ptr<memory> bias;
+    std::shared_ptr<memory> dst;
+
+  public:
+    ConvNode(const convolution_forward::primitive_desc& desc,
+             const std::shared_ptr<memory>& src,
+             const std::shared_ptr<memory>& weights,
+             const std::shared_ptr<memory>& bias,
+             const std::shared_ptr<memory>& dst)
+      : MklNode(convolution_forward(desc),
+                { { MKLDNN_ARG_SRC, *src },
+                  { MKLDNN_ARG_WEIGHTS, *weights },
+                  { MKLDNN_ARG_BIAS, *bias },
+                  { MKLDNN_ARG_DST, *dst } }),
+                src(src), weights(weights), bias(bias), dst(dst)
+    {}
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+  };
+
+  // Pooling node
+  class PoolNode : public MklNode
+  {
+  private:
+    std::shared_ptr<memory> src;
+    std::shared_ptr<memory> dst;
+
+  public:
+    PoolNode(const pooling_forward::primitive_desc& desc,
+             const std::shared_ptr<memory>& src,
+             const std::shared_ptr<memory>& dst)
+      : MklNode(pooling_forward(desc),
+                { { MKLDNN_ARG_SRC, *src },
+                  { MKLDNN_ARG_DST, *dst } }),
+                src(src), dst(dst)
+    {}
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+  };
+
+  // Reorder node
+  class ReorderNode : public MklNode
+  {
+  private:
+    std::shared_ptr<memory> src;
+    std::shared_ptr<memory> dst;
+
+  public:
+    ReorderNode(const std::shared_ptr<memory>& src,
+                const std::shared_ptr<memory>& dst)
+      : MklNode(reorder(reorder::primitive_desc(*src, *dst)),
+                { { MKLDNN_ARG_SRC, *src },
+                  { MKLDNN_ARG_DST, *dst } }),
+                src(src), dst(dst)
+    {}
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+  };
+
+} // namespace oidn
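Nodes report their temporary-memory needs through getScratchpadSize() and receive a shared buffer through setScratchpad(); because the nodes execute one after another, a single buffer of the maximum size is enough. A standalone sketch of that contract (the SketchNode struct below is hypothetical):

    #include <algorithm>
    #include <cstddef>
    #include <memory>
    #include <vector>

    struct SketchNode
    {
      size_t scratchpadSize;                       // what getScratchpadSize() would return
      std::shared_ptr<std::vector<char>> scratch;  // what setScratchpad() would store
    };

    int main()
    {
      std::vector<SketchNode> nodes = {{1 << 20, nullptr}, {4 << 20, nullptr}, {2 << 20, nullptr}};

      size_t maxSize = 0;
      for (const auto& n : nodes)
        maxSize = std::max(maxSize, n.scratchpadSize);   // 4 MiB here

      auto scratchpad = std::make_shared<std::vector<char>>(maxSize);
      for (auto& n : nodes)
        n.scratch = scratchpad;                          // every node shares the one buffer
      return 0;
    }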

+ 126 - 0
thirdparty/oidn/core/output_reorder.h

@@ -0,0 +1,126 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "node.h"
+#include "image.h"
+
+namespace oidn {
+
+  // Output reorder node
+  template<int K, class TransferFunction>
+  class OutputReorderNode : public Node
+  {
+  private:
+    // Source
+    std::shared_ptr<memory> src;
+    const float* srcPtr;
+    int H1;
+    int W1;
+
+    // Destination
+    Image output;
+
+    // Tile
+    int h1Begin;
+    int w1Begin;
+    int h2Begin;
+    int w2Begin;
+    int H;
+    int W;
+
+    std::shared_ptr<TransferFunction> transferFunc;
+
+  public:
+    OutputReorderNode(const std::shared_ptr<memory>& src,
+                      const Image& output,
+                      const std::shared_ptr<TransferFunction>& transferFunc)
+      : src(src),
+        output(output),
+        h1Begin(0), w1Begin(0),
+        h2Begin(0), w2Begin(0),
+        H(output.height), W(output.width),
+        transferFunc(transferFunc)
+    {
+      const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+      MAYBE_UNUSED(srcDesc);
+      assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
+      assert(srcDesc.ndims == 4);
+      assert(srcDesc.data_type == memory::data_type::f32);
+      assert(srcDesc.dims[0] == 1);
+      // We assume the output has at most K channels, stored in a single K-wide block
+      assert(srcDesc.dims[1] == K);
+
+      srcPtr = (float*)src->get_data_handle();
+      H1 = srcDesc.dims[2];
+      W1 = srcDesc.dims[3];
+    }
+
+    void setTile(int h1, int w1, int h2, int w2, int H, int W) override
+    {
+      h1Begin = h1;
+      w1Begin = w1;
+      h2Begin = h2;
+      w2Begin = w2;
+      this->H = H;
+      this->W = W;
+    }
+
+    void execute(stream& sm) override
+    {
+      assert(h1Begin + H <= H1);
+      assert(w1Begin + W <= W1);
+      assert(h2Begin + H <= output.height);
+      assert(w2Begin + W <= output.width);
+
+      const int C1 = K;
+
+      parallel_nd(H, [&](int h)
+      {
+        const int h1 = h + h1Begin;
+        const int h2 = h + h2Begin;
+
+        for (int w = 0; w < W; ++w)
+        {
+          const int w1 = w + w1Begin;
+          const int w2 = w + w2Begin;
+          float* dstPtr_C = (float*)output.get(h2, w2);
+
+          // Source is in nChwKc format. Here there is only one channel block (C == K), so the layout is effectively nhwc
+          const float* srcPtr_C = srcPtr + h1*W1*C1 + w1*C1;
+
+          #pragma unroll
+          for (int i = 0; i < 3; ++i)
+          {
+            // Load the value
+            float x = srcPtr_C[i];
+
+            // The CNN output may contain negative values or even NaNs, so it must be sanitized
+            x = maxSafe(x, 0.f);
+
+            // Apply the inverse transfer function
+            x = transferFunc->inverse(x);
+
+            // Sanitize and store the final value
+            dstPtr_C[i] = max(x, 0.f);
+          }
+        }
+      });
+    }
+  };
+
+} // namespace oidn
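OutputReorderNode reads from MKL-DNN's blocked nChwKc layout, in which channels are grouped into blocks of K and each (h, w) position stores K consecutive channel values. A standalone sketch of the indexing (the blockedOffset helper is hypothetical):

    #include <cassert>

    // Hypothetical helper: offset of element (c, h, w) of a 1xCxHxW tensor stored in nChwKc.
    inline long blockedOffset(int c, int h, int w, int W, int H, int K)
    {
      return (long(c / K) * H * W + long(h) * W + w) * K + (c % K);
    }

    int main()
    {
      const int K = 8, H = 4, W = 4;
      // With C == K (the output reorder's assumption) the layout is effectively nhwc:
      assert(blockedOffset(0, 1, 2, W, H, K) == (1 * W + 2) * K + 0);
      assert(blockedOffset(3, 1, 2, W, H, K) == (1 * W + 2) * K + 3);
      return 0;
    }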

+ 103 - 0
thirdparty/oidn/core/transfer_function.cpp

@@ -0,0 +1,103 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#include "transfer_function.h"
+
+namespace oidn {
+
+  const float LogTransferFunction::xScale = 1.f / log(LogTransferFunction::yMax + 1.f);
+  const float PQXTransferFunction::xScale = 1.f / PQXTransferFunction::pqxForward(PQXTransferFunction::yMax * PQXTransferFunction::yScale);
+
+  float AutoexposureNode::autoexposure(const Image& color)
+  {
+    assert(color.format == Format::Float3);
+
+    constexpr float key = 0.18f;
+    constexpr float eps = 1e-8f;
+    constexpr int K = 16; // downsampling amount
+
+    // Downsample the image to minimize sensitivity to noise
+    const int H  = color.height;  // original height
+    const int W  = color.width;   // original width
+    const int HK = (H + K/2) / K; // downsampled height
+    const int WK = (W + K/2) / K; // downsampled width
+
+    // Compute the average log luminance of the downsampled image
+    using Sum = std::pair<float, int>;
+
+    // -- GODOT start --
+    // Sum sum =
+    //   tbb::parallel_reduce(
+    //     tbb::blocked_range2d<int>(0, HK, 0, WK),
+    //     Sum(0.f, 0),
+    //     [&](const tbb::blocked_range2d<int>& r, Sum sum) -> Sum
+    //     {
+    //       // Iterate over blocks
+    //       for (int i = r.rows().begin(); i != r.rows().end(); ++i)
+    //       {
+    //         for (int j = r.cols().begin(); j != r.cols().end(); ++j)
+    //         {
+
+    Sum sum = Sum(0.0f, 0);
+
+    for (int i = 0; i != HK; ++i)
+    {
+      for (int j = 0; j != WK; ++j)
+      {
+        // Compute the average luminance in the current block
+        const int beginH = int(ptrdiff_t(i)   * H / HK);
+        const int beginW = int(ptrdiff_t(j)   * W / WK);
+        const int endH   = int(ptrdiff_t(i+1) * H / HK);
+        const int endW   = int(ptrdiff_t(j+1) * W / WK);
+
+        float L = 0.f;
+
+        for (int h = beginH; h < endH; ++h)
+        {
+          for (int w = beginW; w < endW; ++w)
+          {
+            const float* rgb = (const float*)color.get(h, w);
+
+            const float r = maxSafe(rgb[0], 0.f);
+            const float g = maxSafe(rgb[1], 0.f);
+            const float b = maxSafe(rgb[2], 0.f);
+
+            L += luminance(r, g, b);
+          }
+        }
+
+        L /= (endH - beginH) * (endW - beginW);
+
+        // Accumulate the log luminance
+        if (L > eps)
+        {
+          sum.first += log2(L);
+          sum.second++;
+        }
+      }
+    }
+
+    //     return sum;
+    //   },
+    //   [](Sum a, Sum b) -> Sum { return Sum(a.first+b.first, a.second+b.second); },
+    //   tbb::static_partitioner()
+    // );
+    // -- GODOT end --
+
+    return (sum.second > 0) ? (key / exp2(sum.first / float(sum.second))) : 1.f;
+  }
+
+} // namespace oidn
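The autoexposure computed above is key / exp2(mean log2 luminance) with key = 0.18, so a uniformly lit image of luminance L is scaled to middle grey. A small standalone sanity check of that relation, with a made-up luminance value:

    #include <cassert>
    #include <cmath>

    int main()
    {
      const float key = 0.18f;
      const float L = 4.f;                // hypothetical uniform luminance
      const float meanLog = std::log2(L); // average of log2 luminance over all blocks
      const float exposure = key / std::exp2(meanLog);
      assert(std::fabs(exposure - key / L) < 1e-6f);
      assert(std::fabs(exposure * L - key) < 1e-6f); // the scaled image lands on middle grey
      return 0;
    }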

+ 201 - 0
thirdparty/oidn/core/transfer_function.h

@@ -0,0 +1,201 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "image.h"
+#include "node.h"
+
+namespace oidn {
+
+  __forceinline float luminance(float r, float g, float b)
+  {
+    return 0.212671f * r + 0.715160f * g + 0.072169f * b;
+  }
+
+  // Color transfer function base class
+  class TransferFunction
+  {
+  public:
+    virtual ~TransferFunction() = default;
+
+    virtual float forward(float y) const = 0;
+    virtual float inverse(float x) const = 0;
+  };
+
+  // HDR transfer function base class
+  class HDRTransferFunction : public TransferFunction
+  {
+  protected:
+    static constexpr float yMax = 65504.f;
+
+    float exposure;
+    float rcpExposure;
+
+  public:
+    HDRTransferFunction(float exposure = 1.f)
+    {
+      setExposure(exposure);
+    }
+
+    void setExposure(float exposure)
+    {
+      this->exposure = exposure;
+      this->rcpExposure = (exposure != 0.f) ? (1.f / exposure) : 0.f;
+    }
+  };
+
+  // Linear transfer function (LDR)
+  class LinearTransferFunction : public TransferFunction
+  {
+  public:
+    __forceinline float forward(float y) const override
+    {
+      return min(y, 1.f);
+    }
+
+    __forceinline float inverse(float x) const override
+    {
+      return min(x, 1.f);
+    }
+  };
+
+  // 2.2 gamma transfer function (LDR)
+  class GammaTransferFunction : public TransferFunction
+  {
+  public:
+    __forceinline float forward(float y) const override
+    {
+      return min(pow(y, 1.f/2.2f), 1.f);
+    }
+
+    __forceinline float inverse(float x) const override
+    {
+      return min(pow(x, 2.2f), 1.f);
+    }
+  };
+
+  // Logarithmic transfer function (HDR)
+  // Compresses [0..65504] to [0..1]
+  class LogTransferFunction : public HDRTransferFunction
+  {
+  private:
+    static const float xScale;
+
+  public:
+    LogTransferFunction(float exposure = 1.f)
+      : HDRTransferFunction(exposure)
+    {
+    }
+
+    __forceinline float forward(float y) const override
+    {
+      return log(y * exposure + 1.f) * xScale;
+    }
+
+    __forceinline float inverse(float x) const override
+    {
+      return (exp(x * (1.f/xScale)) - 1.f) * rcpExposure;
+    }
+  };
+
+  // PQX transfer function (HDR)
+  // Compresses [0..65504] to [0..1]
+  class PQXTransferFunction : public HDRTransferFunction
+  {
+  private:
+    static constexpr float m1 = 2610.f / 4096.f / 4.f;
+    static constexpr float m2 = 2523.f / 4096.f * 128.f;
+    static constexpr float c1 = 3424.f / 4096.f;
+    static constexpr float c2 = 2413.f / 4096.f * 32.f;
+    static constexpr float c3 = 2392.f / 4096.f * 32.f;
+    static constexpr float  a = 3711.f / 4096.f / 8.f;
+
+    static constexpr float yScale = 100.f / 10000.f;
+    static const float     xScale;
+
+  public:
+    PQXTransferFunction(float exposure = 1.f)
+      : HDRTransferFunction(exposure)
+    {
+    }
+
+    __forceinline float forward(float y) const override
+    {
+      return pqxForward(y * exposure * yScale) * xScale;
+    }
+
+    __forceinline float inverse(float x) const override
+    {
+      return pqxInverse(x * (1.f/xScale)) * (1.f/yScale) * rcpExposure;
+    }
+
+  private:
+    static __forceinline float pqForward(float y)
+    {
+      const float yp = pow(y, m1);
+      return pow((c1 + c2 * yp) * rcp(1.f + c3 * yp), m2);
+    }
+
+    static __forceinline float pqxForward(float y)
+    {
+      if (y <= 1.f)
+        return pqForward(y);
+      else
+        return a * log(y) + 1.f;
+    }
+
+    static __forceinline float pqInverse(float x)
+    {
+      const float xp = pow(x, 1.f/m2);
+      return pow(max((xp - c1) * rcp(c2 - c3 * xp), 0.f), 1.f/m1);
+    }
+
+    static __forceinline float pqxInverse(float x)
+    {
+      if (x <= 1.f)
+        return pqInverse(x);
+      else
+        return exp((x - 1.f) * (1.f/a));
+    }
+  };
+
+  // Autoexposure node
+  class AutoexposureNode : public Node
+  {
+  private:
+    Image color;
+    std::shared_ptr<HDRTransferFunction> transferFunc;
+
+  public:
+    AutoexposureNode(const Image& color,
+                     const std::shared_ptr<HDRTransferFunction>& transferFunc)
+      : color(color),
+        transferFunc(transferFunc)
+    {}
+
+    void execute(stream& sm) override
+    {
+      const float exposure = autoexposure(color);
+      //printf("exposure = %f\n", exposure);
+      transferFunc->setExposure(exposure);
+    }
+
+  private:
+    static float autoexposure(const Image& color);
+  };
+
+} // namespace oidn
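LogTransferFunction above compresses the HDR range [0, 65504] into [0, 1], and inverse() undoes forward(). A standalone sketch of that round trip, re-deriving xScale locally and assuming exposure = 1:

    #include <cassert>
    #include <cmath>

    int main()
    {
      const float yMax   = 65504.f;                    // half-float max, as in the header
      const float xScale = 1.f / std::log(yMax + 1.f);

      auto forward = [&](float y) { return std::log(y + 1.f) * xScale; };
      auto inverse = [&](float x) { return std::exp(x / xScale) - 1.f; };

      assert(forward(0.f) == 0.f);
      assert(std::fabs(forward(yMax) - 1.f) < 1e-6f);     // top of the range maps to 1
      const float y = 123.f;
      assert(std::fabs(inverse(forward(y)) - y) < 1e-2f); // round-trips up to float precision
      return 0;
    }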

+ 92 - 0
thirdparty/oidn/core/upsample.h

@@ -0,0 +1,92 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "node.h"
+
+namespace oidn {
+
+  // 2x2 nearest-neighbor upsampling node
+  template<int K>
+  class UpsampleNode : public Node
+  {
+  private:
+    std::shared_ptr<memory> src;
+    std::shared_ptr<memory> dst;
+
+  public:
+    UpsampleNode(const std::shared_ptr<memory>& src,
+                 const std::shared_ptr<memory>& dst)
+      : src(src),
+        dst(dst)
+    {
+      const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+      const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
+      MAYBE_UNUSED(srcDesc);
+      MAYBE_UNUSED(dstDesc);
+      assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
+      assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(BlockedFormat<K>::nChwKc)));
+      assert(srcDesc.ndims == 4);
+      assert(dstDesc.ndims == 4);
+      assert(srcDesc.data_type == memory::data_type::f32);
+      assert(dstDesc.data_type == memory::data_type::f32);
+      assert(srcDesc.dims[0] == 1);
+      assert(dstDesc.dims[0] == 1);
+      // 2x2 upsampling
+      assert(dstDesc.dims[2] == srcDesc.dims[2] * 2);
+      assert(dstDesc.dims[3] == srcDesc.dims[3] * 2);
+    }
+
+    void execute(stream& sm) override
+    {
+      const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+
+      const float* srcPtr = (float*)src->get_data_handle();
+      float* dstPtr = (float*)dst->get_data_handle();
+
+      const int C = srcDesc.dims[1];
+      const int H = srcDesc.dims[2];
+      const int W = srcDesc.dims[3];
+      const int CK = C / K;
+
+      parallel_nd(CK, H, [&](int ck, int h)
+      {
+        const size_t offset = ck*H*W*K + h*W*K;
+        const float* srcPtr_line = srcPtr + offset;
+        float* dstPtr_line0 = dstPtr + offset * 4;
+        float* dstPtr_line1 = dstPtr_line0 + W*2*K; // next line
+
+        for (int w = 0; w < W; ++w)
+        {
+          #pragma unroll
+          for (int k = 0; k < K; k += 4)
+          {
+            const __m128 m = _mm_load_ps(&srcPtr_line[w*K + k]);
+
+            _mm_stream_ps(&dstPtr_line0[w*2*K   + k], m);
+            _mm_stream_ps(&dstPtr_line0[w*2*K+K + k], m);
+            _mm_stream_ps(&dstPtr_line1[w*2*K   + k], m);
+            _mm_stream_ps(&dstPtr_line1[w*2*K+K + k], m);
+          }
+        }
+      });
+    }
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+  };
+
+} // namespace oidn
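The SSE loop in UpsampleNode streams each K-wide channel block into four destination positions, which is plain 2x2 nearest-neighbor replication. A scalar standalone sketch of the same operation on a single-channel image:

    #include <cassert>
    #include <vector>

    int main()
    {
      const int H = 2, W = 2;
      const std::vector<float> src = {1.f, 2.f,
                                      3.f, 4.f};
      std::vector<float> dst(4 * H * W);

      // Each source pixel is copied into a 2x2 block of the 2Hx2W destination.
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          for (int dh = 0; dh < 2; ++dh)
            for (int dw = 0; dw < 2; ++dw)
              dst[(2*h + dh) * (2*W) + (2*w + dw)] = src[h * W + w];

      assert(dst[0] == 1.f && dst[1] == 1.f && dst[4] == 1.f && dst[5] == 1.f);
      assert(dst[2] == 2.f && dst[7] == 2.f);
      return 0;
    }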

+ 99 - 0
thirdparty/oidn/core/weights_reorder.h

@@ -0,0 +1,99 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include "node.h"
+
+namespace oidn {
+
+  // Reorders weights from oihw to padded oihw format
+  template<int K>
+  class WeightsReorderNode : public Node
+  {
+  private:
+    std::shared_ptr<memory> src;
+    std::shared_ptr<memory> dst;
+
+  public:
+    WeightsReorderNode(const std::shared_ptr<memory>& src,
+                       const std::shared_ptr<memory>& dst)
+      : src(src),
+        dst(dst)
+    {
+      const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+      const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
+      MAYBE_UNUSED(srcDesc);
+      MAYBE_UNUSED(dstDesc);
+      assert(memory_desc_matches_tag(srcDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
+      assert(memory_desc_matches_tag(dstDesc, mkldnn_format_tag_t(memory::format_tag::oihw)));
+      assert(srcDesc.ndims == 4);
+      assert(dstDesc.ndims == 4);
+      assert(srcDesc.data_type == memory::data_type::f32);
+      assert(dstDesc.data_type == memory::data_type::f32);
+      assert(getPadded<K>(srcDesc.dims[0]) == dstDesc.dims[0]); // OC
+      assert(getPadded<K>(srcDesc.dims[1]) == dstDesc.dims[1]); // IC
+      assert(srcDesc.dims[2] == dstDesc.dims[2]);
+      assert(srcDesc.dims[3] == dstDesc.dims[3]);
+    }
+
+    void execute(stream& sm) override
+    {
+      const mkldnn_memory_desc_t& srcDesc = src->get_desc().data;
+      const mkldnn_memory_desc_t& dstDesc = dst->get_desc().data;
+
+      const float* srcPtr = (float*)src->get_data_handle();
+      float* dstPtr = (float*)dst->get_data_handle();
+
+      const int OC1 = srcDesc.dims[0];
+      const int OC2 = dstDesc.dims[0];
+      const int IC1 = srcDesc.dims[1];
+      const int IC2 = dstDesc.dims[1];
+      const int H   = dstDesc.dims[2];
+      const int W   = dstDesc.dims[3];
+
+      for (int oc = 0; oc < OC2; ++oc)
+      {
+        for (int ic = 0; ic < IC2; ++ic)
+        {
+          for (int h = 0; h < H; ++h)
+          {
+            for (int w = 0; w < W; ++w)
+            {
+              // Output is in oihw format
+              float* dstPtr_c = dstPtr + oc*IC2*H*W + ic*H*W + h*W + w;
+
+              if (oc < OC1 && ic < IC1)
+              {
+                // Input is in oihw format
+                const float* srcPtr_c = srcPtr + oc*IC1*H*W + ic*H*W + h*W + w;
+                *dstPtr_c = *srcPtr_c;
+              }
+              else
+              {
+                // padding
+                *dstPtr_c = 0;
+              }
+            }
+          }
+        }
+      }
+    }
+
+    std::shared_ptr<memory> getDst() const override { return dst; }
+  };
+
+} // namespace oidn
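WeightsReorderNode copies an oihw weight tensor into a larger oihw tensor whose channel counts were rounded up to K, zero-filling the extra channels. A standalone sketch of that copy on tiny, made-up dimensions (a 1x1 kernel, so the oihw index reduces to oc*IC2 + ic):

    #include <cassert>
    #include <vector>

    int main()
    {
      const int OC1 = 2, IC1 = 2, OC2 = 8, IC2 = 8;
      const std::vector<float> src = {1.f, 2.f, 3.f, 4.f};  // oihw, 2x2x1x1
      std::vector<float> dst(OC2 * IC2, -1.f);

      for (int oc = 0; oc < OC2; ++oc)
        for (int ic = 0; ic < IC2; ++ic)
          dst[oc * IC2 + ic] = (oc < OC1 && ic < IC1) ? src[oc * IC1 + ic] : 0.f;

      assert(dst[0] == 1.f && dst[1] == 2.f);   // original weights stay in place
      assert(dst[8] == 3.f && dst[9] == 4.f);   // second output channel row
      assert(dst[2] == 0.f && dst[63] == 0.f);  // padded entries are zero
      return 0;
    }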

+ 214 - 0
thirdparty/oidn/include/OpenImageDenoise/oidn.h

@@ -0,0 +1,214 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include <stddef.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#include "version.h"
+
+#if defined(__cplusplus)
+extern "C" {
+#endif
+
+#ifndef OIDN_API
+#if defined(_WIN32) && !defined(OIDN_STATIC_LIB)
+#  define OIDN_API __declspec(dllimport)
+#else
+#  define OIDN_API
+#endif
+#endif
+
+// ----------------------------------------------------------------------------
+// Device
+// ----------------------------------------------------------------------------
+
+// Device types
+typedef enum
+{
+  OIDN_DEVICE_TYPE_DEFAULT = 0, // select device automatically
+
+  OIDN_DEVICE_TYPE_CPU = 1, // CPU device
+} OIDNDeviceType;
+
+// Error codes
+typedef enum
+{
+  OIDN_ERROR_NONE                 = 0, // no error occurred
+  OIDN_ERROR_UNKNOWN              = 1, // an unknown error occurred
+  OIDN_ERROR_INVALID_ARGUMENT     = 2, // an invalid argument was specified
+  OIDN_ERROR_INVALID_OPERATION    = 3, // the operation is not allowed
+  OIDN_ERROR_OUT_OF_MEMORY        = 4, // not enough memory to execute the operation
+  OIDN_ERROR_UNSUPPORTED_HARDWARE = 5, // the hardware (e.g. CPU) is not supported
+  OIDN_ERROR_CANCELLED            = 6, // the operation was cancelled by the user
+} OIDNError;
+
+// Error callback function
+typedef void (*OIDNErrorFunction)(void* userPtr, OIDNError code, const char* message);
+
+// Device handle
+typedef struct OIDNDeviceImpl* OIDNDevice;
+
+// Creates a new device.
+OIDN_API OIDNDevice oidnNewDevice(OIDNDeviceType type);
+
+// Retains the device (increments the reference count).
+OIDN_API void oidnRetainDevice(OIDNDevice device);
+
+// Releases the device (decrements the reference count).
+OIDN_API void oidnReleaseDevice(OIDNDevice device);
+
+// Sets a boolean parameter of the device.
+OIDN_API void oidnSetDevice1b(OIDNDevice device, const char* name, bool value);
+
+// Sets an integer parameter of the device.
+OIDN_API void oidnSetDevice1i(OIDNDevice device, const char* name, int value);
+
+// Gets a boolean parameter of the device.
+OIDN_API bool oidnGetDevice1b(OIDNDevice device, const char* name);
+
+// Gets an integer parameter of the device (e.g. "version").
+OIDN_API int oidnGetDevice1i(OIDNDevice device, const char* name);
+
+// Sets the error callback function of the device.
+OIDN_API void oidnSetDeviceErrorFunction(OIDNDevice device, OIDNErrorFunction func, void* userPtr);
+
+// Returns the first unqueried error code stored in the device for the current
+// thread, optionally also returning a string message (if not NULL), and clears
+// the stored error. Can be called with a NULL device as well to check why a
+// device creation failed.
+OIDN_API OIDNError oidnGetDeviceError(OIDNDevice device, const char** outMessage);
+
+// Commits all previous changes to the device.
+// Must be called before first using the device (e.g. creating filters).
+OIDN_API void oidnCommitDevice(OIDNDevice device);
+
+// ----------------------------------------------------------------------------
+// Buffer
+// ----------------------------------------------------------------------------
+
+// Formats for images and other data stored in buffers
+typedef enum
+{
+  OIDN_FORMAT_UNDEFINED = 0,
+
+  // 32-bit single-precision floating point scalar and vector formats
+  OIDN_FORMAT_FLOAT  = 1,
+  OIDN_FORMAT_FLOAT2 = 2,
+  OIDN_FORMAT_FLOAT3 = 3,
+  OIDN_FORMAT_FLOAT4 = 4,
+} OIDNFormat;
+
+// Access modes for mapping buffers
+typedef enum
+{
+  OIDN_ACCESS_READ          = 0, // read-only access
+  OIDN_ACCESS_WRITE         = 1, // write-only access
+  OIDN_ACCESS_READ_WRITE    = 2, // read and write access
+  OIDN_ACCESS_WRITE_DISCARD = 3, // write-only access, previous contents discarded
+} OIDNAccess;
+
+// Buffer handle
+typedef struct OIDNBufferImpl* OIDNBuffer;
+
+// Creates a new buffer (data allocated and owned by the device).
+OIDN_API OIDNBuffer oidnNewBuffer(OIDNDevice device, size_t byteSize);
+
+// Creates a new shared buffer (data allocated and owned by the user).
+OIDN_API OIDNBuffer oidnNewSharedBuffer(OIDNDevice device, void* ptr, size_t byteSize);
+
+// Maps a region of the buffer to host memory.
+// If byteSize is 0, the maximum available amount of memory will be mapped.
+OIDN_API void* oidnMapBuffer(OIDNBuffer buffer, OIDNAccess access, size_t byteOffset, size_t byteSize);
+
+// Unmaps a region of the buffer.
+// mappedPtr must be a pointer returned by a previous call to oidnMapBuffer.
+OIDN_API void oidnUnmapBuffer(OIDNBuffer buffer, void* mappedPtr);
+
+// Retains the buffer (increments the reference count).
+OIDN_API void oidnRetainBuffer(OIDNBuffer buffer);
+
+// Releases the buffer (decrements the reference count).
+OIDN_API void oidnReleaseBuffer(OIDNBuffer buffer);
+
+// ----------------------------------------------------------------------------
+// Filter
+// ----------------------------------------------------------------------------
+
+// Progress monitor callback function
+typedef bool (*OIDNProgressMonitorFunction)(void* userPtr, double n);
+
+// Filter handle
+typedef struct OIDNFilterImpl* OIDNFilter;
+
+// Creates a new filter of the specified type (e.g. "RT").
+OIDN_API OIDNFilter oidnNewFilter(OIDNDevice device, const char* type);
+
+// Retains the filter (increments the reference count).
+OIDN_API void oidnRetainFilter(OIDNFilter filter);
+
+// Releases the filter (decrements the reference count).
+OIDN_API void oidnReleaseFilter(OIDNFilter filter);
+
+// Sets an image parameter of the filter (stored in a buffer).
+// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
+OIDN_API void oidnSetFilterImage(OIDNFilter filter, const char* name,
+                                 OIDNBuffer buffer, OIDNFormat format,
+                                 size_t width, size_t height,
+                                 size_t byteOffset,
+                                 size_t bytePixelStride, size_t byteRowStride);
+
+// Sets an image parameter of the filter (owned by the user).
+// If bytePixelStride and/or byteRowStride are zero, these will be computed automatically.
+OIDN_API void oidnSetSharedFilterImage(OIDNFilter filter, const char* name,
+                                       void* ptr, OIDNFormat format,
+                                       size_t width, size_t height,
+                                       size_t byteOffset,
+                                       size_t bytePixelStride, size_t byteRowStride);
+
+// Sets a boolean parameter of the filter.
+OIDN_API void oidnSetFilter1b(OIDNFilter filter, const char* name, bool value);
+
+// Gets a boolean parameter of the filter.
+OIDN_API bool oidnGetFilter1b(OIDNFilter filter, const char* name);
+
+// Sets an integer parameter of the filter.
+OIDN_API void oidnSetFilter1i(OIDNFilter filter, const char* name, int value);
+
+// Gets an integer parameter of the filter.
+OIDN_API int oidnGetFilter1i(OIDNFilter filter, const char* name);
+
+// Sets a float parameter of the filter.
+OIDN_API void oidnSetFilter1f(OIDNFilter filter, const char* name, float value);
+
+// Gets a float parameter of the filter.
+OIDN_API float oidnGetFilter1f(OIDNFilter filter, const char* name);
+
+// Sets the progress monitor callback function of the filter.
+OIDN_API void oidnSetFilterProgressMonitorFunction(OIDNFilter filter, OIDNProgressMonitorFunction func, void* userPtr);
+
+// Commits all previous changes to the filter.
+// Must be called before first executing the filter.
+OIDN_API void oidnCommitFilter(OIDNFilter filter);
+
+// Executes the filter.
+OIDN_API void oidnExecuteFilter(OIDNFilter filter);
+
+#if defined(__cplusplus)
+}
+#endif
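A minimal usage sketch of the C API declared above (not part of this commit): device and filter creation, shared images, commit and execute, then error polling. The image dimensions are made up for illustration; only functions and constants from this header are used.

    #include <OpenImageDenoise/oidn.h>
    #include <cstdio>
    #include <vector>

    int main()
    {
      const size_t width = 640, height = 480;           // made-up resolution
      std::vector<float> color(width * height * 3);     // noisy input, RGB float
      std::vector<float> output(width * height * 3);    // denoised result

      OIDNDevice device = oidnNewDevice(OIDN_DEVICE_TYPE_CPU);
      oidnCommitDevice(device);

      OIDNFilter filter = oidnNewFilter(device, "RT");  // generic ray tracing filter
      oidnSetSharedFilterImage(filter, "color",  color.data(),  OIDN_FORMAT_FLOAT3,
                               width, height, 0, 0, 0); // strides computed automatically
      oidnSetSharedFilterImage(filter, "output", output.data(), OIDN_FORMAT_FLOAT3,
                               width, height, 0, 0, 0);
      oidnCommitFilter(filter);
      oidnExecuteFilter(filter);

      const char* errorMessage;
      if (oidnGetDeviceError(device, &errorMessage) != OIDN_ERROR_NONE)
        std::printf("OIDN error: %s\n", errorMessage);

      oidnReleaseFilter(filter);
      oidnReleaseDevice(device);
      return 0;
    }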

+ 468 - 0
thirdparty/oidn/include/OpenImageDenoise/oidn.hpp

@@ -0,0 +1,468 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#include <algorithm>
+#include "oidn.h"
+
+namespace oidn {
+
+  // --------------------------------------------------------------------------
+  // Buffer
+  // --------------------------------------------------------------------------
+
+  // Formats for images and other data stored in buffers
+  enum class Format
+  {
+    Undefined = OIDN_FORMAT_UNDEFINED,
+
+    // 32-bit single-precision floating point scalar and vector formats
+    Float  = OIDN_FORMAT_FLOAT,
+    Float2 = OIDN_FORMAT_FLOAT2,
+    Float3 = OIDN_FORMAT_FLOAT3,
+    Float4 = OIDN_FORMAT_FLOAT4,
+  };
+
+  // Access modes for mapping buffers
+  enum class Access
+  {
+    Read         = OIDN_ACCESS_READ,          // read-only access
+    Write        = OIDN_ACCESS_WRITE,         // write-only access
+    ReadWrite    = OIDN_ACCESS_READ_WRITE,    // read and write access
+    WriteDiscard = OIDN_ACCESS_WRITE_DISCARD, // write-only access, previous contents discarded
+  };
+
+  // Buffer object with automatic reference counting
+  class BufferRef
+  {
+  private:
+    OIDNBuffer handle;
+
+  public:
+    BufferRef() : handle(nullptr) {}
+    BufferRef(OIDNBuffer handle) : handle(handle) {}
+
+    BufferRef(const BufferRef& other) : handle(other.handle)
+    {
+      if (handle)
+        oidnRetainBuffer(handle);
+    }
+
+    BufferRef(BufferRef&& other) : handle(other.handle)
+    {
+      other.handle = nullptr;
+    }
+
+    BufferRef& operator =(const BufferRef& other)
+    {
+      if (&other != this)
+      {
+        if (other.handle)
+          oidnRetainBuffer(other.handle);
+        if (handle)
+          oidnReleaseBuffer(handle);
+        handle = other.handle;
+      }
+      return *this;
+    }
+
+    BufferRef& operator =(BufferRef&& other)
+    {
+      std::swap(handle, other.handle);
+      return *this;
+    }
+
+    BufferRef& operator =(OIDNBuffer other)
+    {
+      if (other)
+        oidnRetainBuffer(other);
+      if (handle)
+        oidnReleaseBuffer(handle);
+      handle = other;
+      return *this;
+    }
+
+    ~BufferRef()
+    {
+      if (handle)
+        oidnReleaseBuffer(handle);
+    }
+
+    OIDNBuffer getHandle() const
+    {
+      return handle;
+    }
+
+    operator bool() const
+    {
+      return handle != nullptr;
+    }
+
+    // Maps a region of the buffer to host memory.
+    // If byteSize is 0, the maximum available amount of memory will be mapped.
+    void* map(Access access = Access::ReadWrite, size_t byteOffset = 0, size_t byteSize = 0)
+    {
+      return oidnMapBuffer(handle, (OIDNAccess)access, byteOffset, byteSize);
+    }
+
+    // Unmaps a region of the buffer.
+    // mappedPtr must be a pointer returned by a previous call to map.
+    void unmap(void* mappedPtr)
+    {
+      oidnUnmapBuffer(handle, mappedPtr);
+    }
+  };
+
+  // --------------------------------------------------------------------------
+  // Filter
+  // --------------------------------------------------------------------------
+
+  // Progress monitor callback function
+  typedef bool (*ProgressMonitorFunction)(void* userPtr, double n);
+
+  // Filter object with automatic reference counting
+  class FilterRef
+  {
+  private:
+    OIDNFilter handle;
+
+  public:
+    FilterRef() : handle(nullptr) {}
+    FilterRef(OIDNFilter handle) : handle(handle) {}
+
+    FilterRef(const FilterRef& other) : handle(other.handle)
+    {
+      if (handle)
+        oidnRetainFilter(handle);
+    }
+
+    FilterRef(FilterRef&& other) : handle(other.handle)
+    {
+      other.handle = nullptr;
+    }
+
+    FilterRef& operator =(const FilterRef& other)
+    {
+      if (&other != this)
+      {
+        if (other.handle)
+          oidnRetainFilter(other.handle);
+        if (handle)
+          oidnReleaseFilter(handle);
+        handle = other.handle;
+      }
+      return *this;
+    }
+
+    FilterRef& operator =(FilterRef&& other)
+    {
+      std::swap(handle, other.handle);
+      return *this;
+    }
+
+    FilterRef& operator =(OIDNFilter other)
+    {
+      if (other)
+        oidnRetainFilter(other);
+      if (handle)
+        oidnReleaseFilter(handle);
+      handle = other;
+      return *this;
+    }
+
+    ~FilterRef()
+    {
+      if (handle)
+        oidnReleaseFilter(handle);
+    }
+
+    OIDNFilter getHandle() const
+    {
+      return handle;
+    }
+
+    operator bool() const
+    {
+      return handle != nullptr;
+    }
+
+    // Sets an image parameter of the filter (stored in a buffer).
+    void setImage(const char* name,
+                  const BufferRef& buffer, Format format,
+                  size_t width, size_t height,
+                  size_t byteOffset = 0,
+                  size_t bytePixelStride = 0, size_t byteRowStride = 0)
+    {
+      oidnSetFilterImage(handle, name,
+                         buffer.getHandle(), (OIDNFormat)format,
+                         width, height,
+                         byteOffset,
+                         bytePixelStride, byteRowStride);
+    }
+
+    // Sets an image parameter of the filter (owned by the user).
+    void setImage(const char* name,
+                  void* ptr, Format format,
+                  size_t width, size_t height,
+                  size_t byteOffset = 0,
+                  size_t bytePixelStride = 0, size_t byteRowStride = 0)
+    {
+      oidnSetSharedFilterImage(handle, name,
+                               ptr, (OIDNFormat)format,
+                               width, height,
+                               byteOffset,
+                               bytePixelStride, byteRowStride);
+    }
+
+    // Sets a boolean parameter of the filter.
+    void set(const char* name, bool value)
+    {
+      oidnSetFilter1b(handle, name, value);
+    }
+
+    // Sets an integer parameter of the filter.
+    void set(const char* name, int value)
+    {
+      oidnSetFilter1i(handle, name, value);
+    }
+
+    // Sets a float parameter of the filter.
+    void set(const char* name, float value)
+    {
+      oidnSetFilter1f(handle, name, value);
+    }
+
+    // Gets a parameter of the filter.
+    template<typename T>
+    T get(const char* name);
+
+    // Sets the progress monitor callback function of the filter.
+    void setProgressMonitorFunction(ProgressMonitorFunction func, void* userPtr = nullptr)
+    {
+      oidnSetFilterProgressMonitorFunction(handle, (OIDNProgressMonitorFunction)func, userPtr);
+    }
+
+    // Commits all previous changes to the filter.
+    void commit()
+    {
+      oidnCommitFilter(handle);
+    }
+
+    // Executes the filter.
+    void execute()
+    {
+      oidnExecuteFilter(handle);
+    }
+  };
+
+  // Gets a boolean parameter of the filter.
+  template<>
+  inline bool FilterRef::get(const char* name)
+  {
+    return oidnGetFilter1b(handle, name);
+  }
+
+  // Gets an integer parameter of the filter.
+  template<>
+  inline int FilterRef::get(const char* name)
+  {
+    return oidnGetFilter1i(handle, name);
+  }
+
+  // Gets a float parameter of the filter.
+  template<>
+  inline float FilterRef::get(const char* name)
+  {
+    return oidnGetFilter1f(handle, name);
+  }
+
+  // --------------------------------------------------------------------------
+  // Device
+  // --------------------------------------------------------------------------
+
+  // Device types
+  enum class DeviceType
+  {
+    Default = OIDN_DEVICE_TYPE_DEFAULT, // select device automatically
+
+    CPU = OIDN_DEVICE_TYPE_CPU, // CPU device
+  };
+
+  // Error codes
+  enum class Error
+  {
+    None                = OIDN_ERROR_NONE,                 // no error occurred
+    Unknown             = OIDN_ERROR_UNKNOWN,              // an unknown error occurred
+    InvalidArgument     = OIDN_ERROR_INVALID_ARGUMENT,     // an invalid argument was specified
+    InvalidOperation    = OIDN_ERROR_INVALID_OPERATION,    // the operation is not allowed
+    OutOfMemory         = OIDN_ERROR_OUT_OF_MEMORY,        // not enough memory to execute the operation
+    UnsupportedHardware = OIDN_ERROR_UNSUPPORTED_HARDWARE, // the hardware (e.g. CPU) is not supported
+    Cancelled           = OIDN_ERROR_CANCELLED,            // the operation was cancelled by the user
+  };
+
+  // Error callback function
+  typedef void (*ErrorFunction)(void* userPtr, Error code, const char* message);
+
+  // Device object with automatic reference counting
+  class DeviceRef
+  {
+  private:
+    OIDNDevice handle;
+
+  public:
+    DeviceRef() : handle(nullptr) {}
+    DeviceRef(OIDNDevice handle) : handle(handle) {}
+
+    DeviceRef(const DeviceRef& other) : handle(other.handle)
+    {
+      if (handle)
+        oidnRetainDevice(handle);
+    }
+
+    DeviceRef(DeviceRef&& other) : handle(other.handle)
+    {
+      other.handle = nullptr;
+    }
+
+    DeviceRef& operator =(const DeviceRef& other)
+    {
+      if (&other != this)
+      {
+        if (other.handle)
+          oidnRetainDevice(other.handle);
+        if (handle)
+          oidnReleaseDevice(handle);
+        handle = other.handle;
+      }
+      return *this;
+    }
+
+    DeviceRef& operator =(DeviceRef&& other)
+    {
+      std::swap(handle, other.handle);
+      return *this;
+    }
+
+    DeviceRef& operator =(OIDNDevice other)
+    {
+      if (other)
+        oidnRetainDevice(other);
+      if (handle)
+        oidnReleaseDevice(handle);
+      handle = other;
+      return *this;
+    }
+
+    ~DeviceRef()
+    {
+      if (handle)
+        oidnReleaseDevice(handle);
+    }
+
+    OIDNDevice getHandle() const
+    {
+      return handle;
+    }
+
+    operator bool() const
+    {
+      return handle != nullptr;
+    }
+
+    // Sets a boolean parameter of the device.
+    void set(const char* name, bool value)
+    {
+      oidnSetDevice1b(handle, name, value);
+    }
+
+    // Sets an integer parameter of the device.
+    void set(const char* name, int value)
+    {
+      oidnSetDevice1i(handle, name, value);
+    }
+
+    // Gets a parameter of the device.
+    template<typename T>
+    T get(const char* name);
+
+    // Sets the error callback function of the device.
+    void setErrorFunction(ErrorFunction func, void* userPtr = nullptr)
+    {
+      oidnSetDeviceErrorFunction(handle, (OIDNErrorFunction)func, userPtr);
+    }
+
+    // Returns the first unqueried error code and clears the stored error.
+    // Can be called for a null device as well to check why a device creation failed.
+    Error getError()
+    {
+      return (Error)oidnGetDeviceError(handle, nullptr);
+    }
+
+    // Returns the first unqueried error code and string message, and clears the stored error.
+    // Can be called for a null device as well to check why a device creation failed.
+    Error getError(const char*& outMessage)
+    {
+      return (Error)oidnGetDeviceError(handle, &outMessage);
+    }
+
+    // Commits all previous changes to the device.
+    // Must be called before first using the device (e.g. creating filters).
+    void commit()
+    {
+      oidnCommitDevice(handle);
+    }
+
+    // Creates a new buffer (data allocated and owned by the device).
+    BufferRef newBuffer(size_t byteSize)
+    {
+      return oidnNewBuffer(handle, byteSize);
+    }
+
+    // Creates a new shared buffer (data allocated and owned by the user).
+    BufferRef newBuffer(void* ptr, size_t byteSize)
+    {
+      return oidnNewSharedBuffer(handle, ptr, byteSize);
+    }
+
+    // Creates a new filter of the specified type (e.g. "RT").
+    FilterRef newFilter(const char* type)
+    {
+      return oidnNewFilter(handle, type);
+    }
+  };
+
+  // Gets a boolean parameter of the device.
+  template<>
+  inline bool DeviceRef::get(const char* name)
+  {
+    return oidnGetDevice1b(handle, name);
+  }
+
+  // Gets an integer parameter of the device (e.g. "version").
+  template<>
+  inline int DeviceRef::get(const char* name)
+  {
+    return oidnGetDevice1i(handle, name);
+  }
+
+  // Creates a new device.
+  inline DeviceRef newDevice(DeviceType type = DeviceType::Default)
+  {
+    return DeviceRef(oidnNewDevice((OIDNDeviceType)type));
+  }
+
+} // namespace oidn
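The same flow through the C++ wrappers above; DeviceRef and FilterRef release their handles automatically when they go out of scope. Again a sketch with made-up dimensions, not part of this commit.

    #include <OpenImageDenoise/oidn.hpp>
    #include <cstdio>
    #include <vector>

    int main()
    {
      const size_t width = 640, height = 480;           // made-up resolution
      std::vector<float> color(width * height * 3);
      std::vector<float> output(width * height * 3);

      oidn::DeviceRef device = oidn::newDevice();       // DeviceType::Default
      device.commit();

      oidn::FilterRef filter = device.newFilter("RT");
      filter.setImage("color",  color.data(),  oidn::Format::Float3, width, height);
      filter.setImage("output", output.data(), oidn::Format::Float3, width, height);
      filter.commit();
      filter.execute();

      const char* errorMessage;
      if (device.getError(errorMessage) != oidn::Error::None)
        std::printf("OIDN error: %s\n", errorMessage);

      return 0;                                         // handles released by the destructors
    }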

+ 23 - 0
thirdparty/oidn/include/OpenImageDenoise/version.h

@@ -0,0 +1,23 @@
+// ======================================================================== //
+// Copyright 2009-2019 Intel Corporation                                    //
+//                                                                          //
+// Licensed under the Apache License, Version 2.0 (the "License");          //
+// you may not use this file except in compliance with the License.         //
+// You may obtain a copy of the License at                                  //
+//                                                                          //
+//     http://www.apache.org/licenses/LICENSE-2.0                           //
+//                                                                          //
+// Unless required by applicable law or agreed to in writing, software      //
+// distributed under the License is distributed on an "AS IS" BASIS,        //
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
+// See the License for the specific language governing permissions and      //
+// limitations under the License.                                           //
+// ======================================================================== //
+
+#pragma once
+
+#define OIDN_VERSION_MAJOR 1
+#define OIDN_VERSION_MINOR 1
+#define OIDN_VERSION_PATCH 0
+#define OIDN_VERSION 10100
+#define OIDN_VERSION_STRING "1.1.0"
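+
+// Example (illustrative only): OIDN_VERSION encodes the version as
+// major * 10000 + minor * 100 + patch, so a compile-time check looks like:
+//
+//   #if OIDN_VERSION >= 10100
+//   // Open Image Denoise 1.1.0 or newer
+//   #endif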

+ 214 - 0
thirdparty/oidn/mkl-dnn/LICENSE

@@ -0,0 +1,214 @@
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+
+      "Derivative Works" shall mean any work, whether in Source or Object
+      form, that is based on (or derived from) the Work and for which the
+      editorial revisions, annotations, elaborations, or other modifications
+      represent, as a whole, an original work of authorship. For the purposes
+      of this License, Derivative Works shall not include works that remain
+      separable from, or merely link (or bind by name) to the interfaces of,
+      the Work and Derivative Works thereof.
+
+      "Contribution" shall mean any work of authorship, including
+      the original version of the Work and any modifications or additions
+      to that Work or Derivative Works thereof, that is intentionally
+      submitted to Licensor for inclusion in the Work by the copyright owner
+      or by an individual or Legal Entity authorized to submit on behalf of
+      the copyright owner. For the purposes of this definition, "submitted"
+      means any form of electronic, verbal, or written communication sent
+      to the Licensor or its representatives, including but not limited to
+      communication on electronic mailing lists, source code control systems,
+      and issue tracking systems that are managed by, or on behalf of, the
+      Licensor for the purpose of discussing and improving the Work, but
+      excluding communication that is conspicuously marked or otherwise
+      designated in writing by the copyright owner as "Not a Contribution."
+
+      "Contributor" shall mean Licensor and any individual or Legal Entity
+      on behalf of whom a Contribution has been received by Licensor and
+      subsequently incorporated within the Work.
+
+   2. Grant of Copyright License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      copyright license to reproduce, prepare Derivative Works of,
+      publicly display, publicly perform, sublicense, and distribute the
+      Work and such Derivative Works in Source or Object form.
+
+   3. Grant of Patent License. Subject to the terms and conditions of
+      this License, each Contributor hereby grants to You a perpetual,
+      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+      (except as stated in this section) patent license to make, have made,
+      use, offer to sell, sell, import, and otherwise transfer the Work,
+      where such license applies only to those patent claims licensable
+      by such Contributor that are necessarily infringed by their
+      Contribution(s) alone or by combination of their Contribution(s)
+      with the Work to which such Contribution(s) was submitted. If You
+      institute patent litigation against any entity (including a
+      cross-claim or counterclaim in a lawsuit) alleging that the Work
+      or a Contribution incorporated within the Work constitutes direct
+      or contributory patent infringement, then any patent licenses
+      granted to You under this License for that Work shall terminate
+      as of the date such litigation is filed.
+
+   4. Redistribution. You may reproduce and distribute copies of the
+      Work or Derivative Works thereof in any medium, with or without
+      modifications, and in Source or Object form, provided that You
+      meet the following conditions:
+
+      (a) You must give any other recipients of the Work or
+          Derivative Works a copy of this License; and
+
+      (b) You must cause any modified files to carry prominent notices
+          stating that You changed the files; and
+
+      (c) You must retain, in the Source form of any Derivative Works
+          that You distribute, all copyright, patent, trademark, and
+          attribution notices from the Source form of the Work,
+          excluding those notices that do not pertain to any part of
+          the Derivative Works; and
+
+      (d) If the Work includes a "NOTICE" text file as part of its
+          distribution, then any Derivative Works that You distribute must
+          include a readable copy of the attribution notices contained
+          within such NOTICE file, excluding those notices that do not
+          pertain to any part of the Derivative Works, in at least one
+          of the following places: within a NOTICE text file distributed
+          as part of the Derivative Works; within the Source form or
+          documentation, if provided along with the Derivative Works; or,
+          within a display generated by the Derivative Works, if and
+          wherever such third-party notices normally appear. The contents
+          of the NOTICE file are for informational purposes only and
+          do not modify the License. You may add Your own attribution
+          notices within Derivative Works that You distribute, alongside
+          or as an addendum to the NOTICE text from the Work, provided
+          that such additional attribution notices cannot be construed
+          as modifying the License.
+
+      You may add Your own copyright statement to Your modifications and
+      may provide additional or different license terms and conditions
+      for use, reproduction, or distribution of Your modifications, or
+      for any such Derivative Works as a whole, provided Your use,
+      reproduction, and distribution of the Work otherwise complies with
+      the conditions stated in this License.
+
+   5. Submission of Contributions. Unless You explicitly state otherwise,
+      any Contribution intentionally submitted for inclusion in the Work
+      by You to the Licensor shall be under the terms and conditions of
+      this License, without any additional terms or conditions.
+      Notwithstanding the above, nothing herein shall supersede or modify
+      the terms of any separate license agreement you may have executed
+      with Licensor regarding such Contributions.
+
+   6. Trademarks. This License does not grant permission to use the trade
+      names, trademarks, service marks, or product names of the Licensor,
+      except as required for reasonable and customary use in describing the
+      origin of the Work and reproducing the content of the NOTICE file.
+
+   7. Disclaimer of Warranty. Unless required by applicable law or
+      agreed to in writing, Licensor provides the Work (and each
+      Contributor provides its Contributions) on an "AS IS" BASIS,
+      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+      implied, including, without limitation, any warranties or conditions
+      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+      PARTICULAR PURPOSE. You are solely responsible for determining the
+      appropriateness of using or redistributing the Work and assume any
+      risks associated with Your exercise of permissions under this License.
+
+   8. Limitation of Liability. In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
+
+   APPENDIX: How to apply the Apache License to your work.
+
+      To apply the Apache License to your work, attach the following
+      boilerplate notice, with the fields enclosed by brackets "{}"
+      replaced with your own identifying information. (Don't include
+      the brackets!)  The text should be enclosed in the appropriate
+      comment syntax for the file format. We also recommend that a
+      file or class name and description of purpose be included on the
+      same "printed page" as the copyright notice for easier
+      identification within third-party archives.
+
+   Copyright {yyyy} {name of copyright owner}
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
+
+   ============================================================================
+
+   Intel MKL-DNN includes components with separate copyright
+   notices and license terms.
+
+   XByak, 3-clause BSD license
+   Copyright (c) 2007 MITSUNARI Shigeo
+   See full copyright notice and license text in src/cpu/xbyak/COPYRIGHT
+
+   gtest, 3-clause BSD license
+   Copyright 2008, Google Inc.
+   See full copyright notice and license text in tests/gtests/gtest/LICENSE

+ 1771 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn.h

@@ -0,0 +1,1771 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_H
+#define MKLDNN_H
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+
+/* All symbols shall be internal unless marked as MKLDNN_API */
+#if defined _WIN32 || defined __CYGWIN__
+#   define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
+#   define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
+#else
+#   if __GNUC__ >= 4
+#       define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
+#       define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
+#   else
+#       define MKLDNN_HELPER_DLL_IMPORT
+#       define MKLDNN_HELPER_DLL_EXPORT
+#   endif
+#endif
+
+#ifdef MKLDNN_DLL
+#   ifdef MKLDNN_DLL_EXPORTS
+#       define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
+#   else
+#       define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
+#   endif
+#else
+#   define MKLDNN_API
+#endif
+
+#if defined (__GNUC__)
+#   define MKLDNN_DEPRECATED __attribute__((deprecated))
+#elif defined(_MSC_VER)
+#   define MKLDNN_DEPRECATED __declspec(deprecated)
+#else
+#   define MKLDNN_DEPRECATED
+#endif
+
+#include "mkldnn_types.h"
+#include "mkldnn_version.h"
+#endif /* DOXYGEN_SHOULD_SKIP_THIS */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/** @addtogroup c_api C API
+ * @{ */
+
+/** @addtogroup c_api_primitive Primitive operations
+ * @{ */
+
+/** @addtogroup c_api_primitive_common Common primitive operations
+ * @{ */
+
+/** Creates a primitive descriptor @p iterator for given @p op_desc, @p attr,
+ * @p engine, and optionally a hint primitive descriptor from forward
+ * propagation (required for backward propagation). Pass @c NULL for forward
+ * propagation.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_create(
+        mkldnn_primitive_desc_iterator_t *iterator,
+        const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
+        mkldnn_engine_t engine,
+        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);
+
+/** Iterates over primitive descriptors. Returns #mkldnn_iterator_ends if no
+ * more primitive descriptors are available. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_next(
+        mkldnn_primitive_desc_iterator_t iterator);
+
+/** Fetches the current primitive descriptor.
+ *
+ * @note
+ *     The user should delete the fetched primitive descriptor using
+ *     mkldnn_primitive_desc_destroy() once it is no longer needed. */
+mkldnn_primitive_desc_t MKLDNN_API mkldnn_primitive_desc_iterator_fetch(
+        const_mkldnn_primitive_desc_iterator_t iterator);
+
+/** Deletes a primitive descriptor @p iterator */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_iterator_destroy(
+        mkldnn_primitive_desc_iterator_t iterator);
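+
+/* Example (illustrative sketch): iterating over the implementations available
+ * for an operation descriptor op_desc on an engine, both assumed to be created
+ * beforehand; NULL is passed for the default attribute and the forward hint.
+ *
+ *      mkldnn_primitive_desc_iterator_t it;
+ *      mkldnn_primitive_desc_iterator_create(&it, &op_desc, NULL, engine, NULL);
+ *      do {
+ *          mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(it);
+ *          // ... inspect or use pd ...
+ *          mkldnn_primitive_desc_destroy(pd);
+ *      } while (mkldnn_primitive_desc_iterator_next(it) != mkldnn_iterator_ends);
+ *      mkldnn_primitive_desc_iterator_destroy(it);
+ */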
+
+/** Creates a @p primitive_desc using @p op_desc, @p attr, @p engine, and
+ * optionally a hint primitive descriptor from forward propagation. The call is
+ * equivalent to creating a primitive descriptor iterator, immediately fetching
+ * a primitive descriptor, and then destroying the iterator. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_create(
+        mkldnn_primitive_desc_t *primitive_desc,
+        const_mkldnn_op_desc_t op_desc, const_mkldnn_primitive_attr_t attr,
+        mkldnn_engine_t engine,
+        const_mkldnn_primitive_desc_t hint_forward_primitive_desc);
+
+/** Makes a copy of a @p primitive_desc. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_clone(
+        mkldnn_primitive_desc_t *primitive_desc,
+        const_mkldnn_primitive_desc_t existing_primitive_desc);
+
+/** Returns a constant reference to the attribute of a @p primitive_desc.
+ *
+ * @warning
+ *      The user should not destroy the obtained @p attr.
+ *
+ * @warning
+ *      The lifetime of an @p attr is the same as that of a @p primitive_desc,
+ *      so it is illegal to use the @p attr once @p primitive_desc has been
+ *      destroyed. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_get_attr(
+        const_mkldnn_primitive_desc_t primitive_desc,
+        const_mkldnn_primitive_attr_t *attr);
+
+/** Deletes a @p primitive_desc. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_destroy(
+        mkldnn_primitive_desc_t primitive_desc);
+
+/** Queries a primitive descriptor.
+ *
+ * One of the most typical use cases is to query a convolution primitive
+ * descriptor created with source, weights, and destination formats equal
+ * to #mkldnn_format_tag_any about the corresponding memory descriptors
+ * (@p what equals #mkldnn_query_src_md, #mkldnn_query_weights_md, and
+ * #mkldnn_query_dst_md respectively) to be able to prepare memory and
+ * create reorders if required.
+ *
+ * Another quite typical use case is to query an operation primitive
+ * descriptor for a workspace (@p what equals #mkldnn_query_workspace_md).
+ * The returned status #mkldnn_not_required indicates that a workspace is
+ * not required.
+ *
+ * A few other possibilities:
+ *  - query an operation primitive descriptor for the underlying operation
+ *    descriptor (#mkldnn_query_convolution_d, #mkldnn_query_eltwise_d,
+ *    #mkldnn_query_rnn_d, etc.)
+ *  - query an operation primitive descriptor for the implementation
+ *    information string (#mkldnn_query_impl_info_str)
+ *  - query an operation primitive descriptor for the number of inputs and
+ *    outputs (#mkldnn_query_num_of_inputs_s32 and
+ *    #mkldnn_query_num_of_outputs_s32 respectively)
+ *
+ * @sa mkldnn_query_t for more options
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_desc_query(
+        const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
+        int index, void *result);
+
+/** Queries a primitive descriptor for a memory descriptor.
+ *
+ * @returns NULL in case of any error.
+ *
+ * This is just a specialized version of mkldnn_primitive_desc_query
+ * used for convenience.
+ */
+const mkldnn_memory_desc_t MKLDNN_API *mkldnn_primitive_desc_query_md(
+        const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
+        int index);
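+
+/* Example (illustrative sketch): querying a convolution primitive descriptor
+ * conv_pd (assumed to be created beforehand) for the memory descriptors chosen
+ * by the implementation.
+ *
+ *      const mkldnn_memory_desc_t *src_md
+ *              = mkldnn_primitive_desc_query_md(conv_pd, mkldnn_query_src_md, 0);
+ *      const mkldnn_memory_desc_t *weights_md
+ *              = mkldnn_primitive_desc_query_md(conv_pd, mkldnn_query_weights_md, 0);
+ */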
+
+/** Queries a primitive descriptor for a signed 32-bit integer value.
+ *
+ * @returns 0 in case of any error (in particular if the queried entity is
+ * not of type int32_t). Note that 0 might also be the actual returned
+ * value.
+ *
+ * This is just a specialized version of mkldnn_primitive_desc_query
+ * used for convenience.
+ */
+int MKLDNN_API mkldnn_primitive_desc_query_s32(
+        const_mkldnn_primitive_desc_t primitive_desc, mkldnn_query_t what,
+        int index);
+
+/** Creates a @p primitive using a @p primitive_desc descriptor. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_create(
+        mkldnn_primitive_t *primitive,
+        const_mkldnn_primitive_desc_t primitive_desc);
+
+/** Executes a @p primitive using a @p stream, and @p nargs arguments
+ * @p args. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_execute(
+        const_mkldnn_primitive_t primitive, mkldnn_stream_t stream,
+        int nargs, const mkldnn_exec_arg_t *args);
+
+/** Retrieves a reference to the @p primitive_desc descriptor of given @p
+ * primitive.
+ *
+ * @warning
+ *     The returned object must not be destroyed by the user. The @c const
+ *     qualifier of the returned object prevents such attempts. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_get_primitive_desc(
+        const_mkldnn_primitive_t primitive,
+        const_mkldnn_primitive_desc_t *primitive_desc);
+
+/** Deletes a @p primitive. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_destroy(
+        mkldnn_primitive_t primitive);
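+
+/* Example (illustrative sketch): creating a primitive from a primitive
+ * descriptor pd and executing it. The stream and the source/destination memory
+ * objects (stream, src_memory, dst_memory) are assumed to be created
+ * beforehand; MKLDNN_ARG_SRC and MKLDNN_ARG_DST are the standard execution
+ * argument indices defined in mkldnn_types.h.
+ *
+ *      mkldnn_primitive_t prim;
+ *      mkldnn_primitive_create(&prim, pd);
+ *
+ *      mkldnn_exec_arg_t args[] = {
+ *          { MKLDNN_ARG_SRC, src_memory },
+ *          { MKLDNN_ARG_DST, dst_memory },
+ *      };
+ *      mkldnn_primitive_execute(prim, stream, 2, args);
+ *      mkldnn_primitive_destroy(prim);
+ */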
+
+/** @} */
+
+/** @addtogroup c_api_attributes Attributes
+ * An extension for controlling primitive behavior.
+ * @{ */
+
+/** Creates an empty (default) @p attr attribute. All the parameters are set to
+ * default values.
+ *
+ * An empty attribute is used in primitive descriptor creation whenever it
+ * is not passed explicitly, e.g. in mkldnn_primitive_desc_create.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_create(
+        mkldnn_primitive_attr_t *attr);
+
+/** Makes a copy of an @p existing_attr. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_clone(
+        mkldnn_primitive_attr_t *attr,
+        const_mkldnn_primitive_attr_t existing_attr);
+
+/** Deletes an @p attr. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_destroy(
+        mkldnn_primitive_attr_t attr);
+
+/** Returns the scratchpad @p mode set in the attribute @p attr */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_scratchpad_mode(
+        const_mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t *mode);
+
+/** Sets scratchpad @p mode.
+ *
+ * The possible values are: #mkldnn_scratchpad_mode_library (default) and
+ * #mkldnn_scratchpad_mode_user. */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_scratchpad_mode(
+        mkldnn_primitive_attr_t attr, mkldnn_scratchpad_mode_t mode);
+
+/** Returns @p count, correspondence scale @p mask, and a pointer to a constant
+ * floating point array of output @p scales for given @p attr, previously set
+ * by mkldnn_primitive_attr_set_output_scales.
+ *
+ * @warning
+ *      The @p scales array points to the internal @p attr field, so the user
+ *      should not modify or destroy @p scales.
+ *
+ * @warning
+ *      The lifetime of @p scales is the same as that of the @p attr to which it
+ *      belongs, so it is illegal to use @p scales after @p attr is destroyed.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_output_scales(
+        const_mkldnn_primitive_attr_t attr, mkldnn_dim_t *count, int *mask,
+        const float **scales);
+
+/** Sets output @p scales for primitive operations. The number of elements @p
+ * count and correspondence scale @p mask are stored for future use.
+ *
+ * The @p mask argument defines the correspondence between the output tensor
+ * dimensions and the @p scales array. Set the i-th bit of @p mask to 1 to use a
+ * dedicated scaling factor for each slice of the output tensor over the i-th
+ * dimension. Set @p mask to 0 to use a common scaling factor for the whole
+ * output tensor.
+ *
+ * @note
+ *      The dimension order is always native and does not depend on the actual
+ *      layout used. Examples:
+ *       - for 2D data the order of dimensions is always: (n, c)
+ *       - for 4D data the order is always: (n, c, h, w)
+ *       - for 5D weights the order is always: (g, oc, ic, kh, kw)
+ *
+ * Example usage:
+ * @code
+ *      int mb = 32, oc = 32, oh = 14, ow = 14; // convolution output params
+ *      float scales[oc] = { ... }; // unique output scales per output channel
+ *      int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
+ *
+ *      mkldnn_convolution_desc_t cd; // create & configure convolution op_desc
+ *
+ *      mkldnn_primitive_attr_t attr;
+ *      mkldnn_primitive_attr_create(&attr);  // create default attributes
+ *      mkldnn_primitive_attr_set_output_scales(attr, oc, 1 << oc_dim, scales);
+ *
+ *      mkldnn_primitive_desc_t cpd;
+ *      mkldnn_primitive_desc_create(&cpd, &cd, attr, NULL);
+ * @endcode
+ *
+ * @note
+ *      There is no way to check that @p count corresponds to @p mask until an
+ *      actual primitive descriptor is created, so it is the user's
+ *      responsibility to set proper values. The following formula must hold:
+ *
+ *      \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_output_scales(
+        mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask,
+        const float *scales);
+
+/** Returns @p post_ops for given @p attr.
+ *
+ * @warning
+ *      @p post_ops points to the internal @p attr field, so the user should not
+ *      modify or destroy @p post_ops. Also, the lifetime of @p post_ops is the
+ *      same as that of the @p attr it belongs to, so it is illegal to use @p
+ *      post_ops after @p attr has been destroyed.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_get_post_ops(
+        const_mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t *post_ops);
+
+/** Sets configured @p post_ops to an attribute @p attr for future use (when
+ * primitive descriptor is being created).
+ *
+ * @note
+ *      At this point in time, there is no way to check whether the primitive
+ *      descriptor does or does not support a given sequence of post operations.
+ *      Therefore the user should handle an error that might occur at the
+ *      mkldnn_primitive_desc_create call.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_post_ops(
+        mkldnn_primitive_attr_t attr, const_mkldnn_post_ops_t post_ops);
+
+/** @addtogroup c_api_attributes_post_ops Sequence of post operations
+ * An extension for performing extra operations after a base operation.
+ * @{ */
+
+/** Creates an empty sequence of post operations @p post_ops. */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_create(mkldnn_post_ops_t *post_ops);
+
+/** Deletes a @p post_ops sequence. */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_destroy(mkldnn_post_ops_t post_ops);
+
+/** Returns the number of post operations in the given @p post_ops sequence. */
+int MKLDNN_API mkldnn_post_ops_len(const_mkldnn_post_ops_t post_ops);
+
+/** Returns the type of post operation with index @p index in given
+ * @p post_ops. In case of error, returns #mkldnn_undefined_primitive. */
+mkldnn_primitive_kind_t MKLDNN_API mkldnn_post_ops_get_kind(
+        const_mkldnn_post_ops_t post_ops, int index);
+
+/** Appends accumulation (sum) post operation to the @p post_ops. Prior to
+ * accumulating the result, the previous value would be multiplied by @p scale.
+ *
+ * The kind of this post operation is #mkldnn_sum.
+ *
+ * This feature might improve performance for cases like residual learning
+ * blocks, where the result of convolution is accumulated to the previously
+ * computed activations. The parameter @p scale might be extreme for the
+ * integer-based computations when the result and previous activations have
+ * different logical scaling factors.
+ *
+ * In the simplest case when the accumulation is the only post operation, the
+ * computations would be:
+ * dst[] <- scale * dst[] + op(...) // instead of dst[] <- op(...)
+ *
+ * @note
+ *      This post operation (as well as all the others) disregards the original
+ *      layout of the destination; that is, the layout of the original
+ *      destination is expected to be the same as the layout of the stored
+ *      destination.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_sum(
+        mkldnn_post_ops_t post_ops, float scale);
+
+/** Gets the parameters of the accumulation (sum) post operation with index
+ * @p index in the sequence of @p post_ops.
+ *
+ * @note
+ *      If the post operation at index @p index is not an accumulation (sum),
+ *      the function returns #mkldnn_invalid_arguments.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_sum(
+        const_mkldnn_post_ops_t post_ops, int index, float *scale);
+
+/** Appends eltwise post operation to the @p post_ops with given parameters
+ * @p kind, @p alpha, and @p beta (@sa mkldnn_eltwise_forward_desc_init and
+ * mkldnn_eltwise_desc_t).
+ *
+ * The kind of this post operation is #mkldnn_eltwise.
+ *
+ * In the simplest case when the eltwise is the only post operation, the
+ * computations would be:
+ * dst[] <- scale * eltwise_op ( op(...) ) // instead of dst[] <- op(...)
+ * where eltwise_op is configured with the given parameters.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_append_eltwise(
+        mkldnn_post_ops_t post_ops, float scale, mkldnn_alg_kind_t alg,
+        float alpha, float beta);
+
+/** Gets the eltwise parameters of the post operation with index @p index in
+ * the sequence of @p post_ops.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_post_ops_get_params_eltwise(
+        const_mkldnn_post_ops_t post_ops, int index, float *scale,
+        mkldnn_alg_kind_t *alg, float *alpha, float *beta);
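+
+/* Example (illustrative sketch): building an attribute with a sum post
+ * operation followed by a ReLU eltwise post operation; mkldnn_eltwise_relu is
+ * the eltwise algorithm kind defined in mkldnn_types.h. The attribute would
+ * then be passed to a primitive descriptor creation call.
+ *
+ *      mkldnn_post_ops_t post_ops;
+ *      mkldnn_post_ops_create(&post_ops);
+ *      mkldnn_post_ops_append_sum(post_ops, 1.0f);
+ *      mkldnn_post_ops_append_eltwise(post_ops, 1.0f, mkldnn_eltwise_relu,
+ *              0.0f, 0.0f);
+ *
+ *      mkldnn_primitive_attr_t attr;
+ *      mkldnn_primitive_attr_create(&attr);
+ *      mkldnn_primitive_attr_set_post_ops(attr, post_ops);
+ */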
+
+/** @} */
+
+/** @} */
+
+/** @addtogroup c_api_memory Memory
+ * A primitive to describe and store data.
+ *
+ * The library supports various data types and formats. Memory hierarchy
+ * consists of two levels of abstraction:
+ * 1. **Memory descriptor** -- engine agnostic logical description of data
+ *      (number of dimensions, dimensions themselves, and data type), and
+ *      optionally the format/layout that describes the physical representation
+ *      of data in memory. If the format is not known yet, one can pass
+ *      #mkldnn_format_tag_any. This approach is used to allow compute-intensive
+ *      primitives to specify the most appropriate format on their own with
+ *      users required to reorder the data if the incoming format doesn't match
+ *      the primitive's selection. Memory descriptor can be initialized with
+ *      mkldnn_memory_desc_init_by_tag() or mkldnn_memory_desc_init_by_strides()
+ *      functions, or by directly filling the mkldnn_memory_desc_t structure.
+ *      The latter requires deep knowledge of how the physical data
+ *      representation is mapped to the structure.
+ *      The @ref understanding_memory_formats topic should shed some light on
+ *      that.
+ *      For the fully defined memory descriptors (i.e. where the format kind is
+ *      not equal to #mkldnn_format_kind_any) a user can get the size using the
+ *      mkldnn_memory_desc_get_size() function. As described in
+ *      @ref understanding_memory_formats, the size of data sometimes cannot
+ *      be computed as the product of dimensions times the size of the data
+ *      type. So users are encouraged to use this function for better code
+ *      portability.
+ *      Two memory descriptors can be compared with mkldnn_memory_desc_equal().
+ *      The comparison is especially useful when checking whether a primitive
+ *      requires reorder from the user's data format to the primitive's format.
+ * 2. **Memory** -- an engine-specific object that handles the data and its
+ *      description (a memory descriptor). For the CPU engine, the data handle is
+ *      simply a pointer to @c void. The data handle can be queried using
+ *      mkldnn_memory_get_data_handle() and set using
+ *      mkldnn_memory_set_data_handle(). The latter function always sets the
+ *      memory in the padding region to zero, which is the invariant maintained
+ *      by all the primitives in Intel MKL-DNN.
+ *      See @ref understanding_memory_formats for more details.
+ *      A memory can be created using mkldnn_memory_create() function.
+ *      A memory can also be queried for the underlying memory descriptor and
+ *      engine using mkldnn_memory_get_memory_desc() and
+ *      mkldnn_memory_get_engine() functions.
+ *
+ * Along with ordinary memory with all dimensions being positive, Intel
+ * MKL-DNN supports *zero-volume* memory with one or more dimensions set to
+ * zero. This is to support the NumPy\* convention.
+ * If a *zero-volume* memory is passed to a primitive, the primitive does
+ * not perform any computations on this memory. For example:
+ *  - Convolution with `(0 batch, 3 input channels, 13 height, 13 width)`
+ *    source and `(16 output channels, 3 input channels, 3 height, 3 width)`
+ *    weights would produce `(0 batch, 16 output channels, 11 height, 11 width)`
+ *    destination (assuming strides are `1` and paddings are zero) and perform
+ *    zero multiply-add operations.
+ *  - Concatenation of three memories of shapes `(3, 4, 13, 13)`,
+ *    `(3, 0, 13, 13)`, and `(3, 1, 13, 13)` along the second axis would produce
+ *    the output of the shape `(3, 5, 13, 13)`, effectively ignoring the second
+ *    input (however, if the user created a concatenation primitive descriptor
+ *    with three inputs they should also provide all three memories to the
+ *    concatenation primitive, including the one with zero second dimension).
+ *  - However, Intel MKL-DNN would return an error when attempting to create a
+ *    convolution with *zero-volume* memory passed for weights because such a
+ *    convolution is not well-defined:
+ *    ~~~
+ *    dst(1, 16, 11, 11) <-- src(1, 0, 13, 13) (*) wei(16, 0, 3, 3)
+ *    ~~~
+ *    Should the values in the destination be zeroes or just not accessed at
+ *    all? Moreover, backward pass w.r.t. weights in such cases is also not
+ *    well-defined.
+ *
+ *  Data handle of *zero-volume* memory is never accessed and hence can be
+ *  unset (NULL in case of CPU engine).
+ *
+ * @sa @ref understanding_memory_formats
+ * @{ */
+
+/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p
+ * data_type, and @p strides.
+ *
+ * The @p strides might be NULL, which means the order of physical dimensions
+ * is the same as the order of logical ones.
+ *
+ * @note The logical order of dimensions is defined by a primitive that
+ *       consumes the memory.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_strides(
+        mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims,
+        mkldnn_data_type_t data_type, const mkldnn_dims_t strides);
+
+/** Initializes a @p memory_desc memory descriptor using @p ndims, @p dims, @p
+ * data_type, and format @p tag.
+ *
+ * @p tag can be #mkldnn_format_tag_any, which allows a primitive to define
+ * the appropriate memory format. In this case, the @p format_kind would be set
+ * to #mkldnn_format_kind_any */
+mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_by_tag(
+        mkldnn_memory_desc_t *memory_desc, int ndims, const mkldnn_dims_t dims,
+        mkldnn_data_type_t data_type, mkldnn_format_tag_t tag);
+
+/** Initializes a @p memory_desc for a given @p parent_memory_desc, with
+ * @p dims sizes and @p offsets. May fail if the layout used does not allow
+ * obtaining the desired submemory. In this case, consider using the `extract`
+ * or `insert` primitive. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_desc_init_submemory(
+        mkldnn_memory_desc_t *memory_desc,
+        const mkldnn_memory_desc_t *parent_memory_desc,
+        const mkldnn_dims_t dims, const mkldnn_dims_t offsets);
+
+/** Compares two memory descriptors.
+ * @return 1 if the descriptors are the same.
+ * @return 0 if the descriptors are different.
+ *
+ * Use this function to identify whether a reorder is required between the
+ * two memories */
+int MKLDNN_API mkldnn_memory_desc_equal(
+        const mkldnn_memory_desc_t *lhs,
+        const mkldnn_memory_desc_t *rhs);
+
+/** Returns the size (in bytes) that is required for given @p memory_desc */
+size_t MKLDNN_API mkldnn_memory_desc_get_size(
+        const mkldnn_memory_desc_t *memory_desc);
+
+/** Creates a memory for given @p memory_desc and @p engine. Also sets handle
+ * to @p native_handle.
+ * The @p native_handle can:
+ * - point to user-allocated memory (i.e. a valid handle); in this case the
+ *   library does not own the allocated memory;
+ * - be MKLDNN_NATIVE_HANDLE_ALLOCATE to ask the library to allocate and
+ *   attach memory; in this case the library owns the allocated memory;
+ * - be MKLDNN_NATIVE_HANDLE_NONE to create an mkldnn_memory without attached
+ *   memory.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_memory_create(mkldnn_memory_t *memory,
+        const mkldnn_memory_desc_t *memory_desc, mkldnn_engine_t engine,
+        void *native_handle);
+
+/** Returns a @p memory_desc associated with @p memory. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_get_memory_desc(
+        const_mkldnn_memory_t memory,
+        const mkldnn_memory_desc_t **memory_desc);
+
+/** Returns an @p engine associated with @p memory. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_get_engine(
+        const_mkldnn_memory_t memory, mkldnn_engine_t *engine);
+
+/** For a @p memory, returns the data @p handle.
+ *
+ * For the CPU engine, the data handle is a pointer to the actual data. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_get_data_handle(
+        const_mkldnn_memory_t memory, void **handle);
+
+/** For a @p memory, sets the data @p handle. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_set_data_handle(
+        mkldnn_memory_t memory, void *handle);
+
+/** Deletes a @p memory. */
+mkldnn_status_t MKLDNN_API mkldnn_memory_destroy(mkldnn_memory_t memory);
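+
+/* Example (illustrative sketch): describing a 4D f32 tensor in NCHW layout,
+ * querying its size, and wrapping user-allocated data into a memory object.
+ * The CPU engine is assumed to be created beforehand with the engine API
+ * declared later in this file; mkldnn_f32 and mkldnn_nchw come from
+ * mkldnn_types.h.
+ *
+ *      mkldnn_dims_t dims = {1, 3, 227, 227};
+ *      mkldnn_memory_desc_t md;
+ *      mkldnn_memory_desc_init_by_tag(&md, 4, dims, mkldnn_f32, mkldnn_nchw);
+ *
+ *      size_t size = mkldnn_memory_desc_get_size(&md);
+ *      void *data = malloc(size);
+ *
+ *      mkldnn_memory_t mem;
+ *      mkldnn_memory_create(&mem, &md, engine, data);
+ */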
+
+/** @} */
+
+/** @addtogroup c_api_reorder Reorder
+ * A primitive to copy data between memory formats.
+ * @{ */
+
+/** Initializes a @p reorder_primitive_desc using the description of the source
+ * (@p src_engine and @p src_md) and destination (@p dst_engine and @p dst_md)
+ * memory, and an @p attr attribute.
+ *
+ * Inputs:
+ *  - input (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - output (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_reorder_primitive_desc_create(
+        mkldnn_primitive_desc_t *reorder_primitive_desc,
+        mkldnn_engine_t src_engine, const mkldnn_memory_desc_t *src_md,
+        mkldnn_engine_t dst_engine, const mkldnn_memory_desc_t *dst_md,
+        const_mkldnn_primitive_attr_t attr);
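+
+/* Example (illustrative sketch): creating a reorder primitive descriptor that
+ * converts between two memory descriptors user_md and prim_md on the same CPU
+ * engine (all assumed to be created beforehand), using a default attribute.
+ *
+ *      mkldnn_primitive_attr_t attr;
+ *      mkldnn_primitive_attr_create(&attr);
+ *
+ *      mkldnn_primitive_desc_t reorder_pd;
+ *      mkldnn_reorder_primitive_desc_create(&reorder_pd,
+ *              engine, &user_md, engine, &prim_md, attr);
+ */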
+
+/** @} */
+
+/** @addtogroup c_api_concat Concat
+ * A primitive to concatenate data by arbitrary dimension.
+ * @{ */
+
+/** Creates out-of-place @p concat_primitive_desc for concatenation of @p n
+ * inputs by @p concat_dimension with resulting @p dst_md memory
+ * descriptor. @p dst_md can be NULL or specified with the
+ * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory
+ * format would be chosen automatically.
+ *
+ * Inputs:
+ *  - input 0 (#mkldnn_query_src_md, 0)
+ *  - input 1 (#mkldnn_query_src_md, 1)
+ *  - ...
+ *  - input @p n - 1 (#mkldnn_query_src_md, @p n - 1)
+ *
+ * Outputs:
+ *  - output (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_concat_primitive_desc_create(
+        mkldnn_primitive_desc_t *concat_primitive_desc,
+        const mkldnn_memory_desc_t *dst_md,
+        int n, int concat_dimension,
+        const mkldnn_memory_desc_t *src_mds,
+        const_mkldnn_primitive_attr_t attr,
+        mkldnn_engine_t engine);
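+
+/* Example (illustrative sketch): concatenating two sources along the channel
+ * dimension (axis 1). The source memory descriptors and the engine are assumed
+ * to be created beforehand; NULL is passed for the destination descriptor (so
+ * the library chooses the format) and for the default attribute.
+ *
+ *      mkldnn_memory_desc_t src_mds[2] = {src0_md, src1_md};
+ *
+ *      mkldnn_primitive_desc_t concat_pd;
+ *      mkldnn_concat_primitive_desc_create(&concat_pd, NULL, 2, 1,
+ *              src_mds, NULL, engine);
+ */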
+
+/** @} */
+
+/** @addtogroup c_api_sum Sum
+ * A primitive to sum data.
+ * @{ */
+
+/** Creates out-of-place @p sum_primitive_desc for sum of @p n
+ * inputs multiplied by the corresponding @p scales, with resulting @p dst_mds
+ * memory descriptor. @p dst_mds can be NULL or specified with the
+ * #mkldnn_format_kind_any format kind -- in this case, the appropriate memory
+ * format would be chosen automatically.
+ *
+ * Inputs:
+ *  - src 0 (#mkldnn_query_src_md, 0)
+ *  - src 1 (#mkldnn_query_src_md, 1)
+ *  - ...
+ *  - src @p n - 1 (#mkldnn_query_src_md, @p n - 1)
+ *
+ * Outputs:
+ *  - output (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_sum_primitive_desc_create(
+        mkldnn_primitive_desc_t *sum_primitive_desc,
+        const mkldnn_memory_desc_t *dst_mds,
+        int n, const float *scales,
+        const mkldnn_memory_desc_t *src_mds,
+        const_mkldnn_primitive_attr_t attr,
+        mkldnn_engine_t engine);
+
+/** @} */
+
+/** @addtogroup c_api_convolution Convolution
+ * A primitive to compute convolution using different algorithms.
+ *
+ * \f[dst[n][oc][oh][ow]  =
+ *     \sum_{kw=0}^{KW}\sum_{kh=0}^{KH}\sum_{ic=0}^{IC}
+ *     src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw]
+ *     \cdot weights[g][oc][ic][kh][kw]
+ *     + bias[g][oc],\f]
+ *
+ * where size of output spatial domain is given by
+ * \f$ OH = \left\lfloor{\frac{IH - KH + p_l[0] + p_r[0]}{s_h}}
+ *          \right\rfloor + 1\f$,
+ * \f$ OW = \left\lfloor{\frac{IW - KW + p_l[1] + p_r[1]}{s_w}}
+ *          \right\rfloor + 1\f$,
+ *
+ * and summation is carried over input channels \f$ic\f$ in
+ * group \f$g\f$, and \f$s_h, s_w\f$ are @p strides and
+ * \f$p_l, p_r\f$ are @p padding_l and @p padding_r.
+ * @{ */
+
+/** Initializes a convolution descriptor @p conv_desc for forward propagation
+ * using @p prop_kind (possible values are #mkldnn_forward_training and
+ * #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides, @p
+ * padding_l, @p padding_r, and @p padding_kind. In order to create a
+ * convolution without bias, @p bias_desc should either be @c NULL or point to
+ * a descriptor with memory format kind equal to #mkldnn_format_kind_undef.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *  - bias (#mkldnn_query_weights_md, 1), if created with bias
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_convolution_forward_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
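+
+/* Example (illustrative sketch): a 3x3 forward inference convolution over
+ * memory descriptors src_md, weights_md, bias_md, and dst_md (assumed to be
+ * initialized beforehand, possibly with #mkldnn_format_tag_any);
+ * mkldnn_convolution_direct and mkldnn_padding_zero come from mkldnn_types.h.
+ *
+ *      mkldnn_dims_t strides = {1, 1};
+ *      mkldnn_dims_t padding = {1, 1};
+ *
+ *      mkldnn_convolution_desc_t conv_desc;
+ *      mkldnn_convolution_forward_desc_init(&conv_desc,
+ *              mkldnn_forward_inference, mkldnn_convolution_direct,
+ *              &src_md, &weights_md, &bias_md, &dst_md,
+ *              strides, padding, padding, mkldnn_padding_zero);
+ */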
+
+/** Initializes a dilated convolution descriptor @p conv_desc for forward
+ * propagation using @p prop_kind (possible values are #mkldnn_forward_training
+ * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
+ * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
+ * In order to create a dilated convolution without bias, @p bias_desc
+ * should either be @c NULL or point to a descriptor with memory format kind
+ * equal to #mkldnn_format_kind_undef.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *  - bias (#mkldnn_query_weights_md, 1), if created with bias
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_forward_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a convolution descriptor @p conv_desc for backward propagation
+ * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
+ * padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_data_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a dilated convolution descriptor @p conv_desc for backward
+ * propagation with respect to data using @p alg_kind, memory descriptors, @p
+ * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_dilated_convolution_backward_data_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a convolution descriptor @p conv_desc for backward propagation
+ * with respect to weights using @p alg_kind, memory descriptors, @p strides,
+ * @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_weights (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
+ */
+mkldnn_status_t MKLDNN_API mkldnn_convolution_backward_weights_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *diff_weights_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a convolution descriptor @p conv_desc for backward propagation
+ * with respect to weights using @p alg_kind, memory descriptors, @p strides,
+ * @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_weights (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
+ */
+mkldnn_status_t MKLDNN_API
+mkldnn_dilated_convolution_backward_weights_desc_init(
+        mkldnn_convolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *diff_weights_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** @} */
+
+/** @addtogroup c_api_deconvolution Deconvolution
+ * A primitive to compute deconvolution using different algorithms.
+ *
+ * @{ */
+
+
+/** Initializes a deconvolution descriptor @p deconv_desc for forward
+ * propagation using @p prop_kind (possible values are #mkldnn_forward_training
+ * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
+ * @p padding_l, @p padding_r, and @p padding_kind. In order to create a
+ * deconvolution without bias, @p bias_desc should either be @c NULL or point to
+ * a descriptor with memory format kind equal to #mkldnn_format_kind_undef.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *  - bias (#mkldnn_query_weights_md, 1), if created with bias
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_deconvolution_forward_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a dilated deconvolution descriptor @p deconv_desc for forward
+ * propagation using @p prop_kind (possible values are #mkldnn_forward_training
+ * and #mkldnn_forward_inference), @p alg_kind, memory descriptors, @p strides,
+ * @p dilates, @p padding_l, @p padding_r, and @p padding_kind. In order to
+ * create a dilated deconvolution without bias, @p bias_desc should either be
+ * @c NULL or point to a descriptor with memory format kind equal to
+ * #mkldnn_format_kind_undef.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *  - bias (#mkldnn_query_weights_md, 1), if created with bias
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_forward_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a deconvolution descriptor @p conv_desc for backward propagation
+ * with respect to data using @p alg_kind, memory descriptors, @p strides, @p
+ * padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_data_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a dilated deconvolution descriptor @p conv_desc for backward
+ * propagation with respect to data using @p alg_kind, memory descriptors, @p
+ * strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_data_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a deconvolution descriptor @p conv_desc for backward propagation
+ * with respect to weights using @p alg_kind, memory descriptors, @p strides,
+ * @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_weights (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
+ */
+mkldnn_status_t MKLDNN_API mkldnn_deconvolution_backward_weights_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *diff_weights_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t padding_l, const mkldnn_dims_t padding_r,
+        mkldnn_padding_kind_t padding_kind);
+
+/** Initializes a dilated deconvolution descriptor @p conv_desc for backward
+ * propagation with respect to weights using @p alg_kind, memory descriptors,
+ * @p strides, @p dilates, @p padding_l, @p padding_r, and @p padding_kind.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_weights (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
+ */
+mkldnn_status_t MKLDNN_API mkldnn_dilated_deconvolution_backward_weights_desc_init(
+        mkldnn_deconvolution_desc_t *conv_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *diff_weights_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t dilates, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** @} */
+
+/** @addtogroup c_api_shuffle Shuffle
+ * A primitive to shuffle data along the axis.
+ * @{ */
+
+/** Initializes a @p shuffle_desc for forward propagation using @p prop_kind,
+ * memory descriptor @p data_desc, @p axis, and @p group_size.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ *
+ */
+mkldnn_status_t MKLDNN_API mkldnn_shuffle_forward_desc_init(
+        mkldnn_shuffle_desc_t *shuffle_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_memory_desc_t *data_desc, int axis,
+        mkldnn_dim_t group_size);
+
+/** Initializes a @p shuffle_desc for backward propagation using memory
+ * descriptor @p diff_data_desc, @p axis, and @p group_size.
+ *
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ *
+ */
+mkldnn_status_t MKLDNN_API mkldnn_shuffle_backward_desc_init(
+        mkldnn_shuffle_desc_t *shuffle_desc,
+        const mkldnn_memory_desc_t *diff_data_desc, int axis,
+        mkldnn_dim_t group_size);
+
+/** @} */
+
+/** @addtogroup c_api_eltwise Eltwise
+ * A primitive to compute element-wise operations such as the parametric
+ * rectified linear unit (ReLU).
+ *
+ * Both forward and backward passes support in-place operation; that is, src
+ * and dst point to the same memory for forward pass, and diff_dst and diff_src
+ * point to the same memory for backward pass.
+ *
+ * @warning Because the original src is required for the backward pass, an
+ * in-place forward pass generally cannot be used during training. However,
+ * for some kinds of element-wise operations (namely ReLU with the alpha
+ * parameter equal to 0), dst and src are interchangeable for the backward
+ * pass, which enables an in-place forward pass even during training.
+ *
+ * @{ */
+
+/** Initializes an @p eltwise_desc for forward propagation using @p prop_kind
+ * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
+ * @p alg_kind algorithm, memory descriptor @p data_desc, @p alpha, and
+ * @p beta parameters.
+ *
+ * @sa mkldnn_eltwise_desc_t for details.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_eltwise_forward_desc_init(
+        mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
+        float alpha, float beta);
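+
+/* Example sketch (illustrative): creating a plain ReLU (alpha = 0) forward
+ * descriptor for inference. `data_md` is assumed to be an
+ * already-initialized memory descriptor.
+ *
+ *      mkldnn_eltwise_desc_t relu_d;
+ *      mkldnn_eltwise_forward_desc_init(&relu_d, mkldnn_forward_inference,
+ *              mkldnn_eltwise_relu, &data_md, 0.f, 0.f);
+ *
+ * With alpha = 0, the in-place caveat above does not apply: src and dst may
+ * alias even during training.
+ */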
+
+/** Initializes an @p eltwise_desc for backward propagation using @p alg_kind
+ * algorithm, memory descriptors @p diff_data_desc and @p data_desc, and the
+ * @p alpha and @p beta parameters.
+ *
+ * @sa mkldnn_eltwise_desc_t for details.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_eltwise_backward_desc_init(
+        mkldnn_eltwise_desc_t *eltwise_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_data_desc,
+        const mkldnn_memory_desc_t *data_desc, float alpha, float beta);
+
+/** @} */
+
+/** @addtogroup c_api_softmax Softmax
+ * A primitive to perform softmax.
+ *
+ * \f[dst[ou][c][in] =
+ *    \frac{\exp(src[ou][c][in] - \max\limits_{c}(src[ou][c][in]))}
+ *    {\sum\limits_{c}\exp(src[ou][c][in]
+ *    - \max\limits_{c}(src[ou][c][in]))},\f]
+ *
+ * where \f$ou, in\f$ are the outer and inner sizes respectively, defined
+ * by @p data_desc.dims and @p softmax_axis.
+ * @{ */
+
+/** Initializes a @p softmax_desc for forward propagation using @p prop_kind
+ * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference)
+ * and memory descriptor @p data_desc.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_softmax_forward_desc_init(
+        mkldnn_softmax_desc_t *softmax_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_memory_desc_t *data_desc, int softmax_axis);
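+
+/* Example sketch (illustrative): softmax over the channel axis of a 2D
+ * (N x C) tensor; `data_md` is assumed to be an already-initialized memory
+ * descriptor in the nc layout, so axis 1 is the channel dimension.
+ *
+ *      mkldnn_softmax_desc_t softmax_d;
+ *      mkldnn_softmax_forward_desc_init(&softmax_d,
+ *              mkldnn_forward_inference, &data_md, 1);
+ */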
+
+/** Initializes a @p softmax_desc for backward propagation using memory
+ * descriptors @p diff_desc and @p data_desc.
+ *
+ * Inputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_softmax_backward_desc_init(
+        mkldnn_softmax_desc_t *softmax_desc,
+        const mkldnn_memory_desc_t *diff_desc,
+        const mkldnn_memory_desc_t *data_desc, int softmax_axis);
+
+/** @} */
+
+/** @addtogroup c_api_pooling Pooling
+ * A primitive to perform max or average pooling.
+ *
+ * Max pooling:
+ * \f[dst[n][oc][oh][ow] =
+ *     \max\limits_{kw,kh}
+ *     (src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw]),\f]
+ *
+ * Average pooling:
+ * \f[dst[n][oc][oh][ow] =
+ *     \frac{1}{KW \cdot KH}\sum\limits_{kw,kh}
+ *     src[n][ic][oh \cdot s_h - p_l[0] + kh][ow \cdot s_w - p_l[1] + kw],\f]
+ *
+ * where \f$p_l, p_r\f$ are @p padding_l and @p padding_r respectively, and
+ * the output spatial dimensions are computed in the same way as in
+ * convolution.
+ *
+ * During training, max pooling requires a workspace on forward
+ * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes to
+ * save indices where maximum was found. The workspace layout is opaque, and
+ * the indices cannot be restored from it. However, one can use backward
+ * pooling to perform up-sampling (used in some detection topologies).
+ *
+ * @{ */
+
+/** Initializes a pooling descriptor @p pool_desc for forward propagation using
+ * @p prop_kind (possible values are #mkldnn_forward_training and
+ * #mkldnn_forward_inference), @p alg_kind, memory descriptors, and pooling
+ * parameters in the spatial domain: @p strides, @p kernel sizes, @p padding_l,
+ * @p padding_r, and @p padding_kind.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if @p alg_kind = #mkldnn_pooling_max and
+ *      @p prop_kind = #mkldnn_forward_training
+ */
+mkldnn_status_t MKLDNN_API mkldnn_pooling_forward_desc_init(
+        mkldnn_pooling_desc_t *pool_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
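+
+/* Example sketch (illustrative): 2x2 max pooling with stride 2 and no
+ * padding; `src_md` and `dst_md` are assumed to be already-initialized
+ * memory descriptors with consistent shapes. The mkldnn_dims_t arrays
+ * cover the spatial dimensions only.
+ *
+ *      mkldnn_dims_t strides = {2, 2}, kernel = {2, 2}, pad = {0, 0};
+ *      mkldnn_pooling_desc_t pool_d;
+ *      mkldnn_pooling_forward_desc_init(&pool_d, mkldnn_forward_training,
+ *              mkldnn_pooling_max, &src_md, &dst_md, strides, kernel,
+ *              pad, pad, mkldnn_padding_zero);
+ *
+ * With max pooling and forward_training, the resulting primitive also
+ * exposes a workspace (see Outputs above).
+ */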
+
+/** Initializes a pooling descriptor @p pool_desc for backward propagation
+ * using @p alg_kind, memory descriptors, and pooling parameters in the spatial
+ * domain: @p strides, @p kernel sizes, @p padding_l, @p padding_r, and @p
+ * padding_kind.
+ *
+ * @note If @p padding_r is @c NULL, the padding is supposed to be symmetric.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if @p alg_kind = #mkldnn_pooling_max
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_pooling_backward_desc_init(
+        mkldnn_pooling_desc_t *pool_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc, const mkldnn_dims_t strides,
+        const mkldnn_dims_t kernel, const mkldnn_dims_t padding_l,
+        const mkldnn_dims_t padding_r, mkldnn_padding_kind_t padding_kind);
+
+/** @} */
+
+/** @addtogroup c_api_lrn LRN
+ * A primitive to perform local response normalization (LRN) across or within
+ * channels.
+ *
+ * LRN across channels:
+ * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
+ *                      \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
+ *                      (src[n][c+i][h][w])^2\right\}^{-\beta}
+ *                      src[n][c][h][w],\f]
+ *
+ * LRN within channels:
+ * \f[dst[n][c][h][w] = \left\{k + \frac{\alpha}{n_{l}}
+ *                      \sum\limits_{i=-(n_{l}-1)/2}^{(n_{l}+1)/2}
+ *                      (src[n][c][h+i][w+i])^2\right\}^{-\beta}
+ *                      src[n][c][h][w],\f]
+ *
+ * where \f$n_{l}\f$ is the @p local_size.
+ *
+ * During training, LRN might or might not require a workspace on forward
+ * (#mkldnn_forward_training) and backward (#mkldnn_backward) passes. The
+ * behavior is implementation specific. Optimized implementations typically
+ * require a workspace and use it to save some intermediate results from the
+ * forward pass that accelerate computations on the backward pass.
+ *
+ * To check whether a workspace is required, query the LRN primitive descriptor
+ * for the workspace (#mkldnn_query_workspace_md). Success indicates that the
+ * workspace is required and its description will be returned.
+ * @sa mkldnn_primitive_desc_query and mkldnn_primitive_desc_query_pd
+ *
+ * @{ */
+
+/** Initializes an @p lrn_desc for forward propagation using @p prop_kind
+ * (possible values are #mkldnn_forward_training and #mkldnn_forward_inference),
+ * @p alg_kind, memory descriptor @p data_desc, and regularization
+ * parameters @p local_size, @p alpha, @p beta, and @p k.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if the underlying implementation requires it
+ */
+mkldnn_status_t MKLDNN_API mkldnn_lrn_forward_desc_init(
+        mkldnn_lrn_desc_t *lrn_desc, mkldnn_prop_kind_t prop_kind,
+        mkldnn_alg_kind_t alg_kind, const mkldnn_memory_desc_t *data_desc,
+        mkldnn_dim_t local_size, float alpha, float beta, float k);
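+
+/* Example sketch (illustrative): cross-channel LRN with a local size of 5
+ * and AlexNet-style parameters; `data_md` is assumed to be an
+ * already-initialized memory descriptor.
+ *
+ *      mkldnn_lrn_desc_t lrn_d;
+ *      mkldnn_lrn_forward_desc_init(&lrn_d, mkldnn_forward_training,
+ *              mkldnn_lrn_across_channels, &data_md, 5, 1e-4f, 0.75f, 1.f);
+ *
+ * After building a primitive descriptor from lrn_d, query it for
+ * #mkldnn_query_workspace_md (as described above) to find out whether this
+ * particular implementation needs a workspace.
+ */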
+
+/** Initializes an @p lrn_desc for backward propagation using @p alg_kind,
+ * memory descriptors @p data_desc and @p diff_data_desc, and regularization
+ * parameters @p local_size, @p alpha, @p beta, and @p k.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if the underlying implementation requires it
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_lrn_backward_desc_init(
+        mkldnn_lrn_desc_t *lrn_desc, mkldnn_alg_kind_t alg_kind,
+        const mkldnn_memory_desc_t *diff_data_desc,
+        const mkldnn_memory_desc_t *data_desc, mkldnn_dim_t local_size,
+        float alpha, float beta, float k);
+
+/** @} */
+
+/** @addtogroup c_api_batch_normalization Batch Normalization
+ * A primitive to perform batch normalization.
+ *
+ * \f[dst[n][c][h][w] = \gamma[c] \frac{src[n][c][h][w] - \mu[c]}
+ *                      {\sqrt{\sigma[c] + eps}} + \beta[c],\f]
+ *
+ * where \f$\gamma[c], \beta[c]\f$ are weights and bias for a channel and,
+ *
+ * \f$\mu[c] = \frac{1}{NHW} \sum\limits_{whn} src[n][c][h][w]\f$,
+ * \f$\sigma[c] = \frac{1}{NHW} \sum\limits_{whn}
+ *                              (src[n][c][h][w] - \mu[c])^2\f$,
+ *
+ * and @c eps is a constant to improve numerical stability.
+ *
+ * Both forward and backward passes support in-place operation; that is, src
+ * and dst point to the same memory for forward pass, and diff_dst and diff_src
+ * point to the same memory for backward pass.
+ *
+ * Batch normalization supports different flavors controlled by
+ * mkldnn_batch_normalization_desc_t. For example, batch normalization can
+ * compute the mean and variance on its own or take them as inputs. It can
+ * either perform scaling and shifting using gamma and beta parameters or not.
+ * Optionally, it can also perform a fused ReLU, which in the case of
+ * training also requires a workspace.
+ *
+ * @sa mkldnn_batch_normalization_desc_t
+ * @{ */
+
+/** Initializes a batch normalization descriptor @p bnrm_desc for forward
+ * propagation using @p prop_kind (possible values are
+ * #mkldnn_forward_training and #mkldnn_forward_inference), memory descriptor
+ * @p data_desc, normalization parameter @p epsilon, and @p flags set using bit
+ * flags of type mkldnn_batch_normalization_desc_t.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - mean (#mkldnn_query_src_md, 1),
+ *      if the #mkldnn_use_global_stats bit flag is set in @p flags
+ *  - variance (#mkldnn_query_src_md, 2),
+ *      if the #mkldnn_use_global_stats bit flag is set in @p flags
+ *  - scale_and_shift (#mkldnn_query_weights_md, 0),
+ *      if the #mkldnn_use_scaleshift bit flag is set in @p flags
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ *  - mean (#mkldnn_query_dst_md, 1),
+ *      if the #mkldnn_use_global_stats bit flag is not set in @p flags
+ *      and @p prop_kind = #mkldnn_forward_training
+ *  - variance (#mkldnn_query_dst_md, 2),
+ *      if the #mkldnn_use_global_stats bit flag is not set in @p flags
+ *      and @p prop_kind = #mkldnn_forward_training
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if the #mkldnn_fuse_bn_relu bit flag is set in @p flags
+ *      and @p prop_kind = #mkldnn_forward_training
+ *
+ * @note In-place operation is supported; that is, dst points to the same memory
+ *       as src.
+ *
+ * @sa mkldnn_batch_normalization_desc_t
+ */
+mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_forward_desc_init(
+        mkldnn_batch_normalization_desc_t *bnrm_desc,
+        mkldnn_prop_kind_t prop_kind, const mkldnn_memory_desc_t *data_desc,
+        float epsilon, unsigned flags);
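+
+/* Example sketch (illustrative): forward-training batch normalization that
+ * computes mean and variance itself, applies scale/shift, and fuses a ReLU;
+ * `data_md` is assumed to be an already-initialized memory descriptor.
+ *
+ *      mkldnn_batch_normalization_desc_t bn_d;
+ *      unsigned flags = mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
+ *      mkldnn_batch_normalization_forward_desc_init(&bn_d,
+ *              mkldnn_forward_training, &data_md, 1e-5f, flags);
+ *
+ * With these flags and prop_kind, the primitive produces mean, variance,
+ * and a workspace in addition to dst (see Outputs above).
+ */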
+
+/** Initializes a batch normalization descriptor @p bnrm_desc for backward
+ * propagation with respect to data and scale-shift parameters using memory
+ * descriptors @p data_desc and @p diff_data_desc, normalization parameter
+ * @p epsilon, and @p flags set using bit flags of type
+ * mkldnn_batch_normalization_desc_t.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - mean (#mkldnn_query_src_md, 1)
+ *  - variance (#mkldnn_query_src_md, 2)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - scale_and_shift (#mkldnn_query_weights_md, 0),
+ *      if the #mkldnn_use_scaleshift bit flag is set in @p flags
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if the #mkldnn_fuse_bn_relu bit flag is set in @p flags
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ *  - diff_scale_and_shift (#mkldnn_query_diff_weights_md, 0),
+ *      if the #mkldnn_use_scaleshift bit flag is set in @p flags
+ *      and @p prop_kind = #mkldnn_backward
+ *
+ * @note In-place operation is supported; that is, diff_src points to the
+ *       same memory as diff_dst.
+ *
+ * @sa mkldnn_batch_normalization_desc_t
+ */
+mkldnn_status_t MKLDNN_API mkldnn_batch_normalization_backward_desc_init(
+        mkldnn_batch_normalization_desc_t *bnrm_desc,
+        mkldnn_prop_kind_t prop_kind,
+        const mkldnn_memory_desc_t *diff_data_desc,
+        const mkldnn_memory_desc_t *data_desc,
+        float epsilon, unsigned flags);
+
+/** @} */
+
+/** @addtogroup c_api_inner_product Inner product
+ * A primitive to compute an inner product.
+ *
+ * The inner product layer is also known as the fully connected layer.
+ * With spatial dimensions:
+ *
+ * \f[dst[n][oc] = \sum\limits_{ic, kh, kw}
+ *                 src[n][ic][kh][kw] \cdot weights[oc][ic][kh][kw]
+ *                 + bias[oc]\f]
+ * @{ */
+
+/** Initializes an inner product descriptor @p ip_desc for forward propagation
+ * using @p prop_kind (possible values are #mkldnn_forward_training and
+ * #mkldnn_forward_inference) and memory descriptors. In order to create an
+ * inner product without bias, @p bias_desc should be either @c NULL or a
+ * pointer to a descriptor with memory format kind equal to
+ * #mkldnn_format_kind_undef.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *  - bias (#mkldnn_query_weights_md, 1), if created with bias
+ *
+ * Outputs:
+ *  - dst (#mkldnn_query_dst_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_inner_product_forward_desc_init(
+        mkldnn_inner_product_desc_t *ip_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_desc);
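+
+/* Example sketch (illustrative): a fully connected layer without bias.
+ * Passing NULL for the bias descriptor (as documented above) creates the
+ * bias-less variant; the remaining memory descriptors are assumed to be
+ * already initialized (possibly with #mkldnn_format_kind_any).
+ *
+ *      mkldnn_inner_product_desc_t ip_d;
+ *      mkldnn_inner_product_forward_desc_init(&ip_d,
+ *              mkldnn_forward_inference, &src_md, &weights_md,
+ *              NULL, &dst_md);
+ */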
+
+/** Initializes an inner product descriptor @p ip_desc for backward propagation
+ * with respect to data using memory descriptors.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *  - weights (#mkldnn_query_weights_md, 0)
+ *
+ * Outputs:
+ *  - diff_src (#mkldnn_query_diff_src_md, 0)
+ */
+mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_data_desc_init(
+        mkldnn_inner_product_desc_t *ip_desc,
+        const mkldnn_memory_desc_t *diff_src_desc,
+        const mkldnn_memory_desc_t *weights_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc);
+
+/** Initializes an inner product descriptor @p ip_desc for backward propagation
+ * with respect to weights using memory descriptors.
+ *
+ * @note Memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src (#mkldnn_query_src_md, 0)
+ *  - diff_dst (#mkldnn_query_diff_dst_md, 0)
+ *
+ * Outputs:
+ *  - diff_weights (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 1), if created with bias
+ */
+mkldnn_status_t MKLDNN_API mkldnn_inner_product_backward_weights_desc_init(
+        mkldnn_inner_product_desc_t *ip_desc,
+        const mkldnn_memory_desc_t *src_desc,
+        const mkldnn_memory_desc_t *diff_weights_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_desc);
+
+/** @} */
+
+/** @addtogroup c_api_rnn RNN
+ * A primitive to compute the common recurrent layer.
+ * @todo add additional description for the group
+ * @{ */
+
+/**
+ * Initializes a recurrent cell descriptor @p rnn_cell_desc
+ * using @p rnn_cell_desc, @p kind (possible values are
+ *  #mkldnn_vanilla_rnn, #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru, and
+ *  #mkldnn_gru_linear_before_reset),
+ *  @p f (possible values are #mkldnn_eltwise_relu and
+ *   #mkldnn_eltwise_tanh), @p flags, @p alpha, and @p clipping.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_cell_desc_init(
+        mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        mkldnn_alg_kind_t kind, mkldnn_alg_kind_t f,
+        unsigned int flags, float alpha, float clipping);
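+
+/* Example sketch (illustrative): a vanilla RNN cell with tanh activation,
+ * no extra flags, and no clipping.
+ *
+ *      mkldnn_rnn_cell_desc_t cell_d;
+ *      mkldnn_rnn_cell_desc_init(&cell_d, mkldnn_vanilla_rnn,
+ *              mkldnn_eltwise_tanh, 0u, 0.f, 0.f);
+ */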
+
+/** Returns the number of gates of a particular @p rnn_cell_desc. */
+int MKLDNN_API mkldnn_rnn_cell_get_gates_count(
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
+
+/** Returns the number of states of a particular @p rnn_cell_desc. */
+int MKLDNN_API mkldnn_rnn_cell_get_states_count(
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc);
+
+/** Sets quantization @p scale and @p shift for RNN data tensors.
+ *  For performance reasons, the low-precision configuration of the RNN
+ *  primitive expects input activations to have the unsigned int8 data type.
+ *  The scale and shift used to quantize floating point data to unsigned
+ *  integers must be passed to the RNN primitive using attributes.
+ *  Example usage:
+ * @code
+ *      // rnn parameters
+ *      int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
+ *      // activations quantization parameters
+ *      float scale = ..., shift = ...;
+ *
+ *      mkldnn_primitive_attr_t rnn_attr;
+ *      // create default attributes
+ *      mkldnn_primitive_attr_create(&rnn_attr);
+ *
+ *      // set scale and shift for int8 quantization of activation
+ *      mkldnn_primitive_attr_set_rnn_data_qparams(rnn_attr, scale, shift);
+ *
+ *      // create & configure rnn op_desc
+ *      mkldnn_rnn_desc_t rnn_d;
+ *      mkldnn_primitive_desc_t rnn_pd;
+ *      mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, rnn_attr, engine, NULL);
+ * @endcode
+ * @note
+ *      Quantization scale and shift are common for src_layer, src_iter,
+ *      dst_iter and dst_layer.
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_data_qparams(
+        mkldnn_primitive_attr_t attr, const float scale, const float shift);
+
+/** Sets quantization scales @p weights_scales for RNN weights tensors.
+ * The low-precision configuration of the RNN primitive expects input weights
+ * to have the signed int8 data type. The scales used to quantize floating
+ * point data to signed integers must be passed to the RNN primitive using
+ * attributes.
+ * The @p mask argument defines the correspondence between the output tensor
+ * dimensions and the @p weights_scales array. Set the i-th bit of @p mask to
+ * 1 to use a dedicated scaling factor for each slice of the output tensor
+ * over the i-th dimension. Set @p mask to 0 to use a common scaling factor
+ * for the whole output tensor. Example usage:
+ * @code
+ *      // rnn parameters
+ *      int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
+ *      // unique output scales per output channel
+ *      float weights_scales[dic * n_gates] = { ... };
+ *      // mask that specifies last two dimensions of ldigo format
+ *      int mask = 0x3;
+ *
+ *      mkldnn_primitive_attr_t attr;
+ *      // create default attributes
+ *      mkldnn_primitive_attr_create(&attr);
+ *
+ *      // set output channel-wise weights scales
+ *      mkldnn_primitive_attr_set_rnn_weights_qparams(attr, dic * n_gates, mask,
+ *              weights_scales);
+ *
+ *      // create & configure rnn op_desc
+ *      mkldnn_rnn_desc_t rnn_d;
+ *      mkldnn_primitive_desc_t rnn_pd;
+ *      mkldnn_primitive_desc_create(&rnn_pd, &rnn_d, attr, engine, NULL);
+ * @endcode
+ * @note
+ *      The dimension order is always native and does not depend on the actual
+ *      layout used. For example, 5 dimensional weights always have
+ *      (l, d, i, g, o) logical dimension ordering.
+ * @note
+ *      Quantization scales are common for weights_layer and weights_iteration.
+ * @note
+ *      There is no way to check that @p count corresponds to @p mask until an
+ *      actual primitive descriptor is created, so it is the user's
+ *      responsibility to set proper values. The following formula must hold:
+ *
+ *      \f[count = \prod\limits_{d \in mask} output.dims[d]\f]
+ */
+mkldnn_status_t MKLDNN_API mkldnn_primitive_attr_set_rnn_weights_qparams(
+        mkldnn_primitive_attr_t attr, mkldnn_dim_t count, int mask,
+        const float *weights_scales);
+
+/** Initializes an RNN descriptor @p rnn_desc for forward propagation
+ * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
+ * @note If @p prop_kind equals #mkldnn_forward_training, you must query a
+ * workspace memory descriptor before creating the primitive.
+ *
+ * @p src_iter_desc, @p bias_desc, and @p dst_iter_desc are allowed to either be
+ * @c NULL or point to a zero memory descriptor, which would indicate that the
+ * RNN primitive should not use them.
+ *
+ * @note All memory descriptors except @p src_iter_desc are allowed to be
+ *       initialized with #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * Inputs:
+ *  - src_layer (#mkldnn_query_src_md, 0)
+ *  - src_iter (#mkldnn_query_src_md, 1), if used
+ *  - weights_layer (#mkldnn_query_weights_md, 0)
+ *  - weights_iter (#mkldnn_query_weights_md, 1)
+ *  - bias (#mkldnn_query_weights_md, 2), if used
+ *
+ * Outputs:
+ *  - dst_layer (#mkldnn_query_dst_md, 0)
+ *  - dst_iter (#mkldnn_query_dst_md, 1), if used
+ *  - workspace (#mkldnn_query_workspace_md, 0),
+ *      if @p prop_kind equals #mkldnn_forward_training
+ */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_forward_desc_init(
+        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        const mkldnn_rnn_direction_t direction,
+        const mkldnn_memory_desc_t *src_layer_desc,
+        const mkldnn_memory_desc_t *src_iter_desc,
+        const mkldnn_memory_desc_t *weights_layer_desc,
+        const mkldnn_memory_desc_t *weights_iter_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_layer_desc,
+        const mkldnn_memory_desc_t *dst_iter_desc);
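+
+/* Example sketch (illustrative): an inference-time RNN that carries no
+ * initial/final iteration state and no bias, so @p src_iter_desc,
+ * @p bias_desc, and @p dst_iter_desc are passed as NULL (as documented
+ * above). `cell_d` and the layer/weights memory descriptors are assumed
+ * to be already initialized.
+ *
+ *      mkldnn_rnn_desc_t rnn_d;
+ *      mkldnn_rnn_forward_desc_init(&rnn_d, mkldnn_forward_inference,
+ *              &cell_d, mkldnn_unidirectional_left2right,
+ *              &src_layer_md, NULL, &weights_layer_md, &weights_iter_md,
+ *              NULL, &dst_layer_md, NULL);
+ */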
+
+/** Initializes an RNN descriptor @p rnn_desc for backward propagation
+ * using @p prop_kind, @p rnn_cell_desc, @p direction, and memory descriptors.
+ *
+ * @note All memory descriptors are allowed to be initialized with
+ *       #mkldnn_format_kind_any value of @p format_kind.
+ *
+ * @p src_iter_desc (simultaneously with @p diff_src_iter_desc),
+ * @p bias_desc (simultaneously with @p diff_bias_desc), and
+ * @p dst_iter_desc (simultaneously with @p diff_dst_iter_desc) are allowed to
+ * either be @c NULL or point to a zero memory descriptor, which would indicate
+ * that the RNN primitive should not use them.
+ *
+ * Inputs:
+ *  - src_layer (#mkldnn_query_src_md, 0)
+ *  - src_iter (#mkldnn_query_src_md, 1), if used
+ *  - weights_layer (#mkldnn_query_weights_md, 0)
+ *  - weights_iter (#mkldnn_query_weights_md, 1)
+ *  - bias (#mkldnn_query_weights_md, 2), if used
+ *  - dst_layer (#mkldnn_query_dst_md, 0)
+ *  - dst_iter (#mkldnn_query_dst_md, 1), if used
+ *  - diff_dst_layer (#mkldnn_query_diff_dst_md, 0)
+ *  - diff_dst_iter (#mkldnn_query_diff_dst_md, 1), if used
+ *  - workspace (#mkldnn_query_workspace_md, 0)
+ *
+ * Outputs:
+ *  - diff_src_layer (#mkldnn_query_diff_src_md, 0)
+ *  - diff_src_iter (#mkldnn_query_diff_src_md, 1), if used
+ *  - diff_weights_layer (#mkldnn_query_diff_weights_md, 0)
+ *  - diff_weights_iter (#mkldnn_query_diff_weights_md, 1)
+ *  - diff_bias (#mkldnn_query_diff_weights_md, 2), if used
+ */
+mkldnn_status_t MKLDNN_API mkldnn_rnn_backward_desc_init(
+        mkldnn_rnn_desc_t *rnn_desc, mkldnn_prop_kind_t prop_kind,
+        const mkldnn_rnn_cell_desc_t *rnn_cell_desc,
+        const mkldnn_rnn_direction_t direction,
+        const mkldnn_memory_desc_t *src_layer_desc,
+        const mkldnn_memory_desc_t *src_iter_desc,
+        const mkldnn_memory_desc_t *weights_layer_desc,
+        const mkldnn_memory_desc_t *weights_iter_desc,
+        const mkldnn_memory_desc_t *bias_desc,
+        const mkldnn_memory_desc_t *dst_layer_desc,
+        const mkldnn_memory_desc_t *dst_iter_desc,
+        const mkldnn_memory_desc_t *diff_src_layer_desc,
+        const mkldnn_memory_desc_t *diff_src_iter_desc,
+        const mkldnn_memory_desc_t *diff_weights_layer_desc,
+        const mkldnn_memory_desc_t *diff_weights_iter_desc,
+        const mkldnn_memory_desc_t *diff_bias_desc,
+        const mkldnn_memory_desc_t *diff_dst_layer,
+        const mkldnn_memory_desc_t *diff_dst_iter_desc);
+
+/** @} */
+
+/** @} */
+
+/** @addtogroup c_api_engine Engine operations
+ * @{ */
+
+/** Returns the number of engines of a particular @p kind. */
+size_t MKLDNN_API mkldnn_engine_get_count(mkldnn_engine_kind_t kind);
+
+/** Creates an @p engine of particular @p kind and @p index. */
+mkldnn_status_t MKLDNN_API mkldnn_engine_create(mkldnn_engine_t *engine,
+        mkldnn_engine_kind_t kind, size_t index);
+
+/** Returns the kind of an @p engine. */
+mkldnn_status_t MKLDNN_API mkldnn_engine_get_kind(mkldnn_engine_t engine,
+        mkldnn_engine_kind_t *kind);
+
+/** Destroys an @p engine. */
+mkldnn_status_t MKLDNN_API mkldnn_engine_destroy(mkldnn_engine_t engine);
+
+/** @} */
+
+/** @addtogroup c_api_stream Execution stream operations
+ * @{ */
+
+/** Creates an execution @p stream for @p engine and with @p flags. */
+mkldnn_status_t MKLDNN_API mkldnn_stream_create(mkldnn_stream_t *stream,
+        mkldnn_engine_t engine, unsigned flags);
+
+/** Destroys an execution @p stream. */
+mkldnn_status_t MKLDNN_API mkldnn_stream_destroy(mkldnn_stream_t stream);
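+
+/* Example sketch (illustrative): creating the first CPU engine and a
+ * default stream on it, then releasing both. Error handling is omitted;
+ * each call returns an mkldnn_status_t that should be checked.
+ *
+ *      mkldnn_engine_t engine;
+ *      mkldnn_stream_t stream;
+ *      mkldnn_engine_create(&engine, mkldnn_cpu, 0);
+ *      mkldnn_stream_create(&stream, engine, mkldnn_stream_default_flags);
+ *      // ... create and execute primitives ...
+ *      mkldnn_stream_destroy(stream);
+ *      mkldnn_engine_destroy(engine);
+ */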
+
+/** @} */
+
+/** @addtogroup c_api_service Service functions
+ * @{ */
+
+/** Sets verbosity level (print information to stdout).
+ * Possible levels are:
+ *  - 0 -- no verbose output (default)
+ *  - 1 -- primitive information at execution
+ *  - 2 -- primitive information at creation and execution
+ *
+ * @note
+ *     Dumping information might affect performance.
+ *     This setting overrides the MKLDNN_VERBOSE environment variable. */
+mkldnn_status_t MKLDNN_API mkldnn_set_verbose(int level);
+
+/** Enables or disables dumping of JIT-generated code.
+ * The enable parameter can be:
+ *  - 0 -- disable
+ *  - any other value -- enable
+ *
+ * @note
+ *     This setting overrides the MKLDNN_JIT_DUMP environment variable. */
+mkldnn_status_t MKLDNN_API mkldnn_set_jit_dump(int enable);
+
+/** Gets library version information.
+ * Version information includes:
+ *  - major -- major version number
+ *  - minor -- minor version number
+ *  - patch -- patch release number
+ *  - hash -- git commit hash */
+const mkldnn_version_t MKLDNN_API *mkldnn_version();
+
+/** @} */
+
+/** @addtogroup c_api_blas BLAS functions
+ * A subset of Basic Linear Algebra Subprograms (BLAS) functions to perform
+ * matrix-matrix multiplication.
+ * @{ */
+
+/** SGEMM performs a matrix-matrix multiplication operation defined as
+ *
+ * C := alpha*op( A )*op( B ) + beta*C
+ *
+ * where
+ *  - op( X ) is one of op( X ) = X or op( X ) = X**T,
+ *  - alpha and beta are scalars,
+ *  - A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix
+ *    and C an m by n matrix.
+ *
+ * The matrices are assumed to be stored in column-major order (the elements
+ * of each matrix column are contiguous in memory).
+ *
+ * @note
+ *      The API is different from the standard BLAS routine
+ *      because it returns mkldnn_status_t for error handling.
+ *      XERBLA is not supported: no error message will be printed
+ *      in case of incorrect parameters. */
+mkldnn_status_t MKLDNN_API mkldnn_sgemm(
+        const char *transa, const char *transb,
+        const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
+        const float *alpha, const float *A, const mkldnn_dim_t *lda,
+        const float *B, const mkldnn_dim_t *ldb,
+        const float *beta, float *C, const mkldnn_dim_t *ldc);
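+
+/* Example sketch (illustrative): C = A * B for small column-major matrices,
+ * using the Fortran-style calling convention above (all scalars passed by
+ * pointer). A is 2x3, B is 3x4, C is 2x4; error handling is omitted.
+ *
+ *      float A[6] = { ... }, B[12] = { ... }, C[8] = { 0 };
+ *      mkldnn_dim_t m = 2, n = 4, k = 3;
+ *      float alpha = 1.f, beta = 0.f;
+ *      mkldnn_sgemm("N", "N", &m, &n, &k, &alpha, A, &m, B, &k,
+ *              &beta, C, &m);
+ */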
+
+/** gemm_s8u8s32 and gemm_s8s8s32 perform a matrix-matrix multiplication
+ * operation and add the result to a scalar-matrix product. For the final
+ * result, a vector is added to each row or column of the output matrix.
+ * The operation is defined as:
+ *
+ * C := alpha*(op(A) + A_offset) * (op(B) + B_offset) + beta*C + C_offset
+ *
+ * where
+ *  - op( X ) = X or op( X ) = X**T,
+ *  - A_offset is an m-by-k matrix with every element equal to the value oa,
+ *  - B_offset is an k-by-n matrix with every element equal to the value ob,
+ *  - C_offset is an m-by-n matrix defined by the oc array, size len:
+ *    - if offsetc = F: len must be at least 1
+ *    - if offsetc = C: len must be at least max(1, m)
+ *    - if offsetc = R: len must be at least max(1, n)
+ *  - alpha and beta are scalars, and A, B and C are matrices, with op( A )
+ *    an m-by-k matrix, op( B ) a k-by-n matrix and C an m-by-n matrix.
+ *
+ * The matrices are assumed to be stored in column-major order (the elements
+ * of each matrix column are contiguous in memory).
+ *
+ * @note
+ *      The API differs from the standard BLAS routine
+ *      because it returns mkldnn_status_t for error handling.
+ *      XERBLA is not supported: no error message will be printed
+ *      in case of incorrect parameters. */
+mkldnn_status_t MKLDNN_API mkldnn_gemm_s8u8s32(
+        const char *transa, const char *transb, const char *offsetc,
+        const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
+        const float *alpha,
+        const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao,
+        const uint8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo,
+        const float *beta,
+        int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co);
+
+mkldnn_status_t MKLDNN_API mkldnn_gemm_s8s8s32(
+        const char *transa, const char *transb, const char *offsetc,
+        const mkldnn_dim_t *M, const mkldnn_dim_t *N, const mkldnn_dim_t *K,
+        const float *alpha,
+        const int8_t *A, const mkldnn_dim_t *lda, const int8_t *ao,
+        const int8_t *B, const mkldnn_dim_t *ldb, const int8_t *bo,
+        const float *beta,
+        int32_t *c, const mkldnn_dim_t *ldc, const int32_t *co);
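+
+/* Example sketch (illustrative): an int8 GEMM with zero A/B offsets and a
+ * single common C offset ("F" selects a fixed, scalar co). The matrices A
+ * (int8_t), B (uint8_t), C (int32_t) and the dimensions/leading dimensions
+ * are assumed to be set up following the same column-major convention as in
+ * the mkldnn_sgemm example above.
+ *
+ *      int8_t ao = 0, bo = 0;
+ *      int32_t co = 0;
+ *      float alpha = 1.f, beta = 0.f;
+ *      mkldnn_gemm_s8u8s32("N", "N", "F", &m, &n, &k, &alpha,
+ *              A, &lda, &ao, B, &ldb, &bo, &beta, C, &ldc, &co);
+ */
+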
+/** @} */
+
+/** @} */
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif

+ 2615 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn.hpp

@@ -0,0 +1,2615 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_HPP
+#define MKLDNN_HPP
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+#include <stdlib.h>
+#include <memory>
+#include <vector>
+#include <unordered_map>
+#include <algorithm>
+#include <iterator>
+
+#include "mkldnn.h"
+#endif
+
+namespace mkldnn {
+
+/// @addtogroup cpp_api C++ API
+/// @{
+
+/// @addtogroup cpp_api_utils Utils
+/// @{
+
+/// A class that provides the destructor for an Intel(R) MKL-DNN C handle.
+template <typename T> class handle_traits {};
+
+/// A class for wrapping an Intel(R) MKL-DNN handle. It is used as the base
+/// class for primitive (#mkldnn_primitive_t), engine (#mkldnn_engine_t), and
+/// stream (#mkldnn_stream_t) handles. An object of the #mkldnn::handle class
+/// can be passed by value. This class enables wrapping:
+///  - Newly constructed handles.
+///    @n In this case, the constructed handle uses reference counting provided
+///    by @p std::shared_ptr with a proper deleter function specified through
+///    the @p handle_traits class.
+///  - Pre-existing handles returned by the Intel(R) MKL-DNN C API (for
+///    example, through mkldnn_primitive_get_primitive_desc()).
+///    @n In this case, an Intel(R) MKL-DNN C API handle is wrapped without a
+///    deleter because it is assumed that the handle wrapper for the original
+///    object deletes the handle (this model is similar to @p std::weak_ptr).
+template <typename T, typename traits=handle_traits<T>> class handle {
+private:
+    std::shared_ptr<typename std::remove_pointer<T>::type> _data;
+    handle(const handle &&) = delete;
+    handle &operator=(const handle &&other) = delete;
+protected:
+    bool operator==(const T other) const { return other == _data.get(); }
+    bool operator!=(const T other) const { return !(*this == other); }
+public:
+    /// Constructs a C handle wrapper.
+    /// @param t The C handle to wrap.
+    /// @param weak A flag to specify whether to construct a weak wrapper.
+    handle(T t = 0, bool weak = false): _data(0) {
+        reset(t, weak);
+    }
+
+    handle(const handle &other): _data(other._data) {}
+    handle &operator=(const handle &other) {
+        _data = other._data;
+        return *this;
+    }
+    /// Resets the value of a C handle.
+    /// @param t The new value of the C handle.
+    /// @param weak A flag to specify whether the wrapper should be weak.
+    void reset(T t, bool weak = false) {
+        auto dummy_destructor = [](T) { return decltype(traits::destructor(0))(0); };
+        _data.reset(t, weak ? dummy_destructor : traits::destructor);
+    }
+
+    /// Returns the value of the underlying C handle.
+    T get() const { return _data.get(); }
+
+    bool operator==(const handle &other) const { return other._data.get() == _data.get(); }
+    bool operator!=(const handle &other) const { return !(*this == other); }
+};
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <> struct handle_traits<mkldnn_memory_t> {
+    static constexpr auto destructor = &mkldnn_memory_destroy;
+};
+
+template <> struct handle_traits<mkldnn_primitive_desc_t> {
+    static constexpr auto destructor = &mkldnn_primitive_desc_destroy;
+};
+
+template <> struct handle_traits<mkldnn_primitive_t> {
+    static constexpr auto destructor = &mkldnn_primitive_destroy;
+};
+
+template <> struct handle_traits<mkldnn_primitive_desc_iterator_t> {
+    static constexpr auto destructor = &mkldnn_primitive_desc_iterator_destroy;
+};
+#endif
+
+struct memory;
+struct primitive_desc;
+
+/// Base class for all computational primitives.
+class primitive: public handle<mkldnn_primitive_t> {
+    friend struct error;
+    friend struct stream;
+    using handle::handle;
+public:
+    /// A proxy to C primitive kind enum
+    enum class kind {
+        undefined_primitive = mkldnn_undefined_primitive,
+        reorder = mkldnn_reorder,
+        concat = mkldnn_concat,
+        sum = mkldnn_sum,
+        convolution = mkldnn_convolution,
+        deconvolution = mkldnn_deconvolution,
+        shuffle = mkldnn_shuffle,
+        eltwise = mkldnn_eltwise,
+        softmax = mkldnn_softmax,
+        pooling = mkldnn_pooling,
+        lrn = mkldnn_lrn,
+        batch_normalization = mkldnn_batch_normalization,
+        inner_product = mkldnn_inner_product,
+        rnn = mkldnn_rnn,
+    };
+
+    primitive(const_mkldnn_primitive_desc_t c_pd);
+    primitive(const primitive_desc &pd);
+
+    /// Returns the descriptor of the underlying C API primitive.
+    inline const_mkldnn_primitive_desc_t get_primitive_desc() const;
+    // TODO: use the C++ API wrapper structure.
+
+    void execute(struct stream &astream,
+            const std::unordered_map<int, memory> &args) const;
+};
+
+inline mkldnn_primitive_kind_t convert_to_c(primitive::kind akind) {
+    return static_cast<mkldnn_primitive_kind_t>(akind);
+}
+/// Intel(R) MKL-DNN exception class.
+///
+/// This class captures the status returned by the failed C API function, error
+/// message, and, optionally, handle of the primitive that caused the error.
+struct error: public std::exception {
+    mkldnn_status_t status;
+    const char *message;
+
+    /// Constructs an error instance.
+    ///
+    /// @param astatus The error status returned by the C API.
+    /// @param amessage The error message.
+    error(mkldnn_status_t astatus, const char *amessage)
+        : status(astatus), message(amessage) {}
+
+    /// A convenience function for wrapping calls to the C API. Checks the
+    /// return status and throws an #error in case of failure.
+    ///
+    /// @param status The error status returned by the C API.
+    /// @param message The error message.
+    static void wrap_c_api(mkldnn_status_t status, const char *message) {
+        if (status != mkldnn_success)
+            throw error(status, message);
+    }
+};
+
+const_mkldnn_primitive_desc_t primitive::get_primitive_desc() const {
+    const_mkldnn_primitive_desc_t pd;
+    error::wrap_c_api(mkldnn_primitive_get_primitive_desc(get(), &pd),
+            "could not get primitive descriptor by primitive");
+    return pd;
+}
+/// @}
+
+/// @addtogroup cpp_api_enums Common data types and enumerations
+/// A proxy to @ref c_api_types in @ref c_api.
+///
+/// @{
+
+enum scratchpad_mode {
+    scratchpad_mode_library = mkldnn_scratchpad_mode_library,
+    scratchpad_mode_user = mkldnn_scratchpad_mode_user,
+};
+
+inline mkldnn_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
+    return static_cast<mkldnn_scratchpad_mode_t>(mode);
+}
+
+enum padding_kind {
+    zero = mkldnn_padding_zero
+};
+
+inline mkldnn_padding_kind_t convert_to_c(padding_kind kind) {
+    return static_cast<mkldnn_padding_kind_t>(kind);
+}
+
+enum prop_kind {
+    forward_training = mkldnn_forward_training,
+    forward_scoring = mkldnn_forward_scoring,
+    forward_inference = mkldnn_forward_inference,
+    forward = mkldnn_forward,
+    backward = mkldnn_backward,
+    backward_data = mkldnn_backward_data,
+    backward_weights = mkldnn_backward_weights,
+    backward_bias = mkldnn_backward_bias
+};
+
+inline mkldnn_prop_kind_t convert_to_c(prop_kind kind) {
+    return static_cast<mkldnn_prop_kind_t>(kind);
+}
+
+enum algorithm {
+    algorithm_undef = mkldnn_alg_kind_undef,
+    convolution_auto = mkldnn_convolution_auto,
+    convolution_direct = mkldnn_convolution_direct,
+    convolution_winograd = mkldnn_convolution_winograd,
+    deconvolution_direct = mkldnn_deconvolution_direct,
+    deconvolution_winograd = mkldnn_deconvolution_winograd,
+    eltwise_relu = mkldnn_eltwise_relu,
+    eltwise_tanh = mkldnn_eltwise_tanh,
+    eltwise_elu = mkldnn_eltwise_elu,
+    eltwise_square = mkldnn_eltwise_square,
+    eltwise_abs = mkldnn_eltwise_abs,
+    eltwise_sqrt = mkldnn_eltwise_sqrt,
+    eltwise_linear = mkldnn_eltwise_linear,
+    eltwise_bounded_relu = mkldnn_eltwise_bounded_relu,
+    eltwise_soft_relu = mkldnn_eltwise_soft_relu,
+    eltwise_logistic = mkldnn_eltwise_logistic,
+    lrn_across_channels = mkldnn_lrn_across_channels,
+    lrn_within_channel  = mkldnn_lrn_within_channel,
+    pooling_max = mkldnn_pooling_max,
+    pooling_avg = mkldnn_pooling_avg,
+    pooling_avg_include_padding = mkldnn_pooling_avg_include_padding,
+    pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding,
+    vanilla_rnn = mkldnn_vanilla_rnn,
+    vanilla_lstm = mkldnn_vanilla_lstm,
+    vanilla_gru = mkldnn_vanilla_gru,
+    gru_linear_before_reset = mkldnn_gru_linear_before_reset
+};
+
+inline mkldnn_alg_kind_t convert_to_c(algorithm aalgorithm) {
+    return static_cast<mkldnn_alg_kind_t>(aalgorithm);
+}
+
+enum batch_normalization_flag {
+    use_global_stats = mkldnn_use_global_stats,
+    use_scale_shift = mkldnn_use_scaleshift,
+    fuse_bn_relu = mkldnn_fuse_bn_relu
+};
+
+inline mkldnn_batch_normalization_flag_t convert_to_c(
+        batch_normalization_flag aflag) {
+    return static_cast<mkldnn_batch_normalization_flag_t>(aflag);
+}
+
+enum rnn_direction {
+    unidirectional_left2right = mkldnn_unidirectional_left2right,
+    unidirectional_right2left = mkldnn_unidirectional_right2left,
+    unidirectional = mkldnn_unidirectional,
+    bidirectional_concat = mkldnn_bidirectional_concat,
+    bidirectional_sum = mkldnn_bidirectional_sum,
+};
+
+inline mkldnn_rnn_direction_t convert_to_c(rnn_direction adir) {
+    return static_cast<mkldnn_rnn_direction_t>(adir);
+}
+
+enum query {
+    undef = mkldnn_query_undef,
+
+    query_engine = mkldnn_query_engine,
+    primitive_kind = mkldnn_query_primitive_kind,
+
+    num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32,
+    num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32,
+
+    time_estimate_f64 = mkldnn_query_time_estimate_f64,
+    memory_consumption_s64 = mkldnn_query_memory_consumption_s64,
+
+    query_scratchpad_engine = mkldnn_query_scratchpad_engine,
+
+    impl_info_str = mkldnn_query_impl_info_str,
+
+    op_d = mkldnn_query_op_d,
+    convolution_d = mkldnn_query_convolution_d,
+    deconvolution_d = mkldnn_query_deconvolution_d,
+    shuffle_d = mkldnn_query_shuffle_d,
+    eltwise_d = mkldnn_query_eltwise_d,
+    softmax_d = mkldnn_query_softmax_d,
+    pooling_d = mkldnn_query_pooling_d,
+    lrn_d = mkldnn_query_lrn_d,
+    batch_normalization_d = mkldnn_query_batch_normalization_d,
+    inner_product_d = mkldnn_query_inner_product_d,
+    rnn_d = mkldnn_query_rnn_d,
+
+    src_md = mkldnn_query_src_md,
+    diff_src_md = mkldnn_query_diff_src_md,
+    weights_md = mkldnn_query_weights_md,
+    diff_weights_md = mkldnn_query_diff_weights_md,
+    dst_md = mkldnn_query_dst_md,
+    diff_dst_md = mkldnn_query_diff_dst_md,
+    workspace_md = mkldnn_query_workspace_md,
+    scratchpad_md = mkldnn_query_scratchpad_md,
+};
+
+inline mkldnn_query_t convert_to_c(query aquery) {
+    return static_cast<mkldnn_query_t>(aquery);
+}
+
+/// @}
+
+/// @addtogroup cpp_api_attr Attributes
+/// An extension for controlling primitive behavior.
+///
+/// @sa @ref c_api_attributes in @ref c_api
+/// @{
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <> struct handle_traits<mkldnn_post_ops_t> {
+    static constexpr auto destructor = &mkldnn_post_ops_destroy;
+};
+#endif
+
+struct post_ops: public handle<mkldnn_post_ops_t> {
+    post_ops() {
+        mkldnn_post_ops_t result;
+        error::wrap_c_api(mkldnn_post_ops_create(&result),
+                "could not create post operation sequence");
+        reset(result);
+    }
+
+    int len() const { return mkldnn_post_ops_len(get()); }
+
+    primitive::kind kind(int index) const {
+        error::wrap_c_api(
+                index < len() ? mkldnn_success : mkldnn_invalid_arguments,
+                "post_ops index is out of range");
+        return static_cast<primitive::kind>(mkldnn_post_ops_get_kind(get(),
+                    index));
+    }
+
+    void append_sum(float scale = 1.) {
+        error::wrap_c_api(mkldnn_post_ops_append_sum(get(), scale),
+                "could not append sum");
+    }
+
+    void get_params_sum(int index, float &scale) const {
+        error::wrap_c_api(mkldnn_post_ops_get_params_sum(get(), index, &scale),
+                "could not get sum params");
+    }
+
+    void append_eltwise(float scale, algorithm alg, float alpha,
+            float beta) {
+        error::wrap_c_api(mkldnn_post_ops_append_eltwise(get(), scale,
+                    convert_to_c(alg), alpha, beta),
+                "could not append eltwise");
+    }
+
+    void get_params_eltwise(int index, float &scale, algorithm &alg,
+            float &alpha, float &beta) const {
+        mkldnn_alg_kind_t c_alg;
+        error::wrap_c_api(mkldnn_post_ops_get_params_eltwise(get(), index,
+                    &scale, &c_alg, &alpha, &beta),
+                "could not get eltwise params");
+        alg = static_cast<algorithm>(c_alg);
+    }
+};
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <> struct handle_traits<mkldnn_primitive_attr_t> {
+    static constexpr auto destructor = &mkldnn_primitive_attr_destroy;
+};
+#endif
+
+struct primitive_attr: public handle<mkldnn_primitive_attr_t> {
+    primitive_attr() {
+        mkldnn_primitive_attr_t result;
+        error::wrap_c_api(mkldnn_primitive_attr_create(&result),
+                "could not create a primitive attr");
+        reset(result);
+    }
+
+    scratchpad_mode get_scratchpad_mode() const {
+        mkldnn_scratchpad_mode_t result;
+        error::wrap_c_api(mkldnn_primitive_attr_get_scratchpad_mode(
+                    get(), &result), "could not get scratchpad mode");
+        return scratchpad_mode(result);
+    }
+
+    void set_scratchpad_mode(scratchpad_mode mode) {
+        error::wrap_c_api(mkldnn_primitive_attr_set_scratchpad_mode(
+                    get(), mkldnn::convert_to_c(mode)),
+                "could not set scratchpad mode");
+    }
+
+    void get_output_scales(int &mask, std::vector<float> &scales) const
+    {
+        mkldnn_dim_t count;
+        int c_mask;
+        const float *c_scales;
+        error::wrap_c_api(mkldnn_primitive_attr_get_output_scales(get(),
+                    &count, &c_mask, &c_scales),
+                "could not get int output scales");
+        scales.resize(count);
+
+        mask = c_mask;
+        for (mkldnn_dim_t c = 0; c < count; ++c)
+            scales[c] = c_scales[c];
+    }
+
+    void set_output_scales(int mask, const std::vector<float> &scales)
+    {
+        error::wrap_c_api(mkldnn_primitive_attr_set_output_scales(get(),
+                    (mkldnn_dim_t)scales.size(), mask, &scales[0]),
+                "could not set int output scales");
+    }
+
+    const post_ops get_post_ops() const {
+        post_ops result;
+        const_mkldnn_post_ops_t c_result;
+        error::wrap_c_api(mkldnn_primitive_attr_get_post_ops(get(), &c_result),
+                "could not get post operation sequence");
+        result.reset(const_cast<mkldnn_post_ops_t>(c_result), true);
+        return result;
+    }
+
+    void set_post_ops(post_ops ops) {
+        error::wrap_c_api(mkldnn_primitive_attr_set_post_ops(get(), ops.get()),
+                "could not set post operation sequence");
+    }
+
+    void set_rnn_data_qparams(const float scale, const float shift)
+    {
+        error::wrap_c_api(mkldnn_primitive_attr_set_rnn_data_qparams(get(),
+                    scale, shift), "could not set rnn data int scale/shift");
+    }
+
+    void set_rnn_weights_qparams(int mask, const std::vector<float> &scales)
+    {
+        error::wrap_c_api(mkldnn_primitive_attr_set_rnn_weights_qparams(get(),
+                    (int)scales.size(), mask, &scales[0]),
+                "could not set rnn weights int scales");
+    }
+};
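+
+// Example sketch (illustrative): attaching a ReLU post-op to a primitive
+// attribute object through the wrappers above. Error handling is done by
+// the wrappers themselves (they throw mkldnn::error on failure).
+//
+//     mkldnn::post_ops ops;
+//     ops.append_eltwise(1.f, mkldnn::eltwise_relu, 0.f, 0.f);
+//     mkldnn::primitive_attr attr;
+//     attr.set_post_ops(ops);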
+
+/// @}
+
+/// @addtogroup cpp_api_engine Engine
+/// Engine operations.
+///
+/// @sa @ref c_api_engine in @ref c_api
+/// @{
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <> struct handle_traits<mkldnn_engine_t> {
+    static constexpr auto destructor = &mkldnn_engine_destroy;
+};
+#endif
+
+/// An execution engine.
+struct engine: public handle<mkldnn_engine_t> {
+    friend class primitive;
+    // gcc bug??? using handle::handle;
+
+    /// Kinds of engines.
+    enum kind {
+        /// An unspecified engine
+        any = mkldnn_any_engine,
+        /// CPU engine
+        cpu = mkldnn_cpu,
+    };
+
+    /// Returns the number of engines of a certain kind.
+    ///
+    /// @param akind The kind of engines to count.
+
+    static size_t get_count(kind akind) {
+        return mkldnn_engine_get_count(convert_to_c(akind));
+    }
+
+    /// Constructs an engine.
+    ///
+    /// @param akind The kind of engine to construct.
+    /// @param index The index of the engine. Must be less than the value
+    ///              returned by #get_count() for this particular kind of engine.
+
+    engine(kind akind, size_t index) {
+        mkldnn_engine_t aengine;
+        error::wrap_c_api(
+                mkldnn_engine_create(&aengine,
+                    convert_to_c(akind), index),
+                "could not create an engine");
+        reset(aengine);
+    }
+
+    explicit engine(const mkldnn_engine_t& aengine)
+        : handle(aengine, true) {}
+
+    engine(const handle<mkldnn_primitive_desc_t> &pd) {
+        mkldnn_engine_t engine_q;
+        error::wrap_c_api(
+                mkldnn_primitive_desc_query(pd.get(),
+                    mkldnn::convert_to_c(query_engine), 0, &engine_q),
+                "could not get engine from primitive_desc");
+        reset(engine_q, true);
+    }
+
+    template <class primitive_desc>
+    static engine query(const primitive_desc &pd) {
+        mkldnn_engine_t engine_q;
+        error::wrap_c_api(
+                mkldnn_primitive_desc_query(pd.get(),
+                    mkldnn::convert_to_c(query_engine), 0, &engine_q),
+                "could not get engine from primitive_desc");
+
+        return engine(engine_q);
+    }
+
+private:
+    static mkldnn_engine_kind_t convert_to_c(kind akind) {
+        return static_cast<mkldnn_engine_kind_t>(akind);
+    }
+};
+
+/// @}
+
+/// @addtogroup cpp_api_stream Stream
+/// Execution stream operations
+///
+/// @sa @ref c_api_stream in @ref c_api
+/// @{
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+template <> struct handle_traits<mkldnn_stream_t> {
+    static constexpr auto destructor = &mkldnn_stream_destroy;
+};
+#endif
+
+struct stream: public handle<mkldnn_stream_t> {
+    using handle::handle;
+
+    enum: unsigned {
+        default_flags = mkldnn_stream_default_flags,
+    };
+
+    /// Constructs a stream.
+    stream(const engine &aengine,
+            unsigned flags = static_cast<unsigned>(default_flags)) {
+        mkldnn_stream_t astream;
+        error::wrap_c_api(mkldnn_stream_create(&astream, aengine.get(), flags),
+                "could not create a stream");
+        reset(astream);
+    }
+};
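+
+// Example sketch (illustrative): the C++ counterparts of the C engine and
+// stream APIs. The constructors throw mkldnn::error on failure, and the
+// underlying handles are released automatically when the wrappers go out
+// of scope.
+//
+//     mkldnn::engine eng(mkldnn::engine::cpu, 0);
+//     mkldnn::stream s(eng);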
+
+/// @}
+
+/// @addtogroup cpp_api_memory_related Memory and memory related operations
+/// @{
+
+/// @addtogroup cpp_api_memory Memory
+/// A primitive to describe and store data.
+///
+/// For more information, refer to @ref c_api_memory in @ref c_api.
+/// @{
+
+/// Memory that describes the data.
+struct memory: public handle<mkldnn_memory_t> {
+    public:
+    typedef mkldnn_dim_t dim;
+    typedef std::vector<dim> dims;
+
+    template <typename T> static void validate_dims(const std::vector<T> &v) {
+        if (v.size() > MKLDNN_MAX_NDIMS)
+            throw error(mkldnn_invalid_arguments, "invalid dimensions");
+    }
+
+    /// Data type specification. See #mkldnn_data_type_t for a detailed
+    /// description.
+    enum data_type {
+        data_undef = mkldnn_data_type_undef,
+        f32 = mkldnn_f32,
+        s32 = mkldnn_s32,
+        s8 = mkldnn_s8,
+        u8 = mkldnn_u8,
+    };
+
+    /// Memory format tag specification. See #mkldnn_format_tag_t
+    /// for a detailed description.
+    enum format_tag {
+        format_tag_undef = mkldnn_format_tag_undef,
+        any = mkldnn_format_tag_any,
+        a = mkldnn_a,
+        ab = mkldnn_ab,
+        abc = mkldnn_abc,
+        abcd = mkldnn_abcd,
+        abcde = mkldnn_abcde,
+        abcdef = mkldnn_abcdef,
+        abdec = mkldnn_abdec,
+        acb = mkldnn_acb,
+        acbde = mkldnn_acbde,
+        acdb = mkldnn_acdb,
+        acdeb = mkldnn_acdeb,
+        ba = mkldnn_ba,
+        bac = mkldnn_bac,
+        bacd = mkldnn_bacd,
+        bcda = mkldnn_bcda,
+        cba = mkldnn_cba,
+        cdba = mkldnn_cdba,
+        cdeba = mkldnn_cdeba,
+        decab = mkldnn_decab,
+        Abc16a = mkldnn_Abc16a,
+        ABc16a16b = mkldnn_ABc16a16b,
+        aBc16b = mkldnn_aBc16b,
+        ABc16b16a = mkldnn_ABc16b16a,
+        Abc4a = mkldnn_Abc4a,
+        aBc4b = mkldnn_aBc4b,
+        ABc4b16a4b = mkldnn_ABc4b16a4b,
+        ABc4b4a = mkldnn_ABc4b4a,
+        ABc8a16b2a = mkldnn_ABc8a16b2a,
+        ABc8a8b = mkldnn_ABc8a8b,
+        aBc8b = mkldnn_aBc8b,
+        ABc8b16a2b = mkldnn_ABc8b16a2b,
+        ABc8b8a = mkldnn_ABc8b8a,
+        Abcd16a = mkldnn_Abcd16a,
+        ABcd16a16b = mkldnn_ABcd16a16b,
+        aBcd16b = mkldnn_aBcd16b,
+        ABcd16b16a = mkldnn_ABcd16b16a,
+        aBCd16b16c = mkldnn_aBCd16b16c,
+        aBCd16c16b = mkldnn_aBCd16c16b,
+        Abcd4a = mkldnn_Abcd4a,
+        aBcd4b = mkldnn_aBcd4b,
+        ABcd4b16a4b = mkldnn_ABcd4b16a4b,
+        ABcd4b4a = mkldnn_ABcd4b4a,
+        aBCd4c16b4c = mkldnn_aBCd4c16b4c,
+        aBCd4c4b = mkldnn_aBCd4c4b,
+        ABcd8a16b2a = mkldnn_ABcd8a16b2a,
+        ABcd8a8b = mkldnn_ABcd8a8b,
+        aBcd8b = mkldnn_aBcd8b,
+        ABcd8b16a2b = mkldnn_ABcd8b16a2b,
+        aBCd8b16c2b = mkldnn_aBCd8b16c2b,
+        ABcd8b8a = mkldnn_ABcd8b8a,
+        aBCd8b8c = mkldnn_aBCd8b8c,
+        aBCd8c16b2c = mkldnn_aBCd8c16b2c,
+        aBCd8c8b = mkldnn_aBCd8c8b,
+        Abcde16a = mkldnn_Abcde16a,
+        ABcde16a16b = mkldnn_ABcde16a16b,
+        aBcde16b = mkldnn_aBcde16b,
+        ABcde16b16a = mkldnn_ABcde16b16a,
+        aBCde16b16c = mkldnn_aBCde16b16c,
+        aBCde16c16b = mkldnn_aBCde16c16b,
+        aBCde2c8b4c = mkldnn_aBCde2c8b4c,
+        Abcde4a = mkldnn_Abcde4a,
+        aBcde4b = mkldnn_aBcde4b,
+        ABcde4b4a = mkldnn_ABcde4b4a,
+        aBCde4b4c = mkldnn_aBCde4b4c,
+        aBCde4c16b4c = mkldnn_aBCde4c16b4c,
+        aBCde4c4b = mkldnn_aBCde4c4b,
+        Abcde8a = mkldnn_Abcde8a,
+        ABcde8a8b = mkldnn_ABcde8a8b,
+        aBcde8b = mkldnn_aBcde8b,
+        ABcde8b16a2b = mkldnn_ABcde8b16a2b,
+        aBCde8b16c2b = mkldnn_aBCde8b16c2b,
+        ABcde8b8a = mkldnn_ABcde8b8a,
+        aBCde8b8c = mkldnn_aBCde8b8c,
+        aBCde8c16b2c = mkldnn_aBCde8c16b2c,
+        aBCde8c8b = mkldnn_aBCde8c8b,
+        aBcdef16b = mkldnn_aBcdef16b,
+        aBCdef16b16c = mkldnn_aBCdef16b16c,
+        aBCdef16c16b = mkldnn_aBCdef16c16b,
+        aBcdef4b = mkldnn_aBcdef4b,
+        aBCdef4c4b = mkldnn_aBCdef4c4b,
+        aBCdef8b8c = mkldnn_aBCdef8b8c,
+        aBCdef8c16b2c = mkldnn_aBCdef8c16b2c,
+        aBCdef8c8b = mkldnn_aBCdef8c8b,
+        aBdc16b = mkldnn_aBdc16b,
+        aBdc4b = mkldnn_aBdc4b,
+        aBdc8b = mkldnn_aBdc8b,
+        aBdec16b = mkldnn_aBdec16b,
+        aBdec4b = mkldnn_aBdec4b,
+        aBdec8b = mkldnn_aBdec8b,
+        aBdefc16b = mkldnn_aBdefc16b,
+        aBdefc4b = mkldnn_aBdefc4b,
+        aBdefc8b = mkldnn_aBdefc8b,
+        Acb16a = mkldnn_Acb16a,
+        Acb4a = mkldnn_Acb4a,
+        Acb8a = mkldnn_Acb8a,
+        aCBd16b16c = mkldnn_aCBd16b16c,
+        aCBde16b16c = mkldnn_aCBde16b16c,
+        Acdb16a = mkldnn_Acdb16a,
+        Acdb4a = mkldnn_Acdb4a,
+        Acdb8a = mkldnn_Acdb8a,
+        Acdeb16a = mkldnn_Acdeb16a,
+        Acdeb4a = mkldnn_Acdeb4a,
+        Acdeb8a = mkldnn_Acdeb8a,
+        BAc16a16b = mkldnn_BAc16a16b,
+        BAcd16a16b = mkldnn_BAcd16a16b,
+        format_tag_last = mkldnn_format_tag_last,
+
+        x = mkldnn_x,
+        nc = mkldnn_nc,
+        cn = mkldnn_cn,
+        ncw = mkldnn_ncw,
+        nwc = mkldnn_nwc,
+        nchw = mkldnn_nchw,
+        nhwc = mkldnn_nhwc,
+        chwn = mkldnn_chwn,
+        ncdhw = mkldnn_ncdhw,
+        ndhwc = mkldnn_ndhwc,
+        oi = mkldnn_oi,
+        io = mkldnn_io,
+        oiw = mkldnn_oiw,
+        wio = mkldnn_wio,
+        oihw = mkldnn_oihw,
+        hwio = mkldnn_hwio,
+        ihwo = mkldnn_ihwo,
+        iohw = mkldnn_iohw,
+        oidhw = mkldnn_oidhw,
+        dhwio = mkldnn_dhwio,
+        goiw = mkldnn_goiw,
+        goihw = mkldnn_goihw,
+        hwigo = mkldnn_hwigo,
+        giohw = mkldnn_giohw,
+        goidhw = mkldnn_goidhw,
+        tnc = mkldnn_tnc,
+        ntc = mkldnn_ntc,
+        ldsnc = mkldnn_ldsnc,
+        ldigo = mkldnn_ldigo,
+        ldgoi = mkldnn_ldgoi,
+        ldgo = mkldnn_ldgo,
+        nCdhw16c = mkldnn_nCdhw16c,
+        nCdhw4c = mkldnn_nCdhw4c,
+        nCdhw8c = mkldnn_nCdhw8c,
+        nChw16c = mkldnn_nChw16c,
+        nChw4c = mkldnn_nChw4c,
+        nChw8c = mkldnn_nChw8c,
+        nCw16c = mkldnn_nCw16c,
+        nCw4c = mkldnn_nCw4c,
+        nCw8c = mkldnn_nCw8c,
+        IOw16o16i = mkldnn_IOw16o16i,
+        OIw16i16o = mkldnn_OIw16i16o,
+        OIw16o16i = mkldnn_OIw16o16i,
+        Oiw16o = mkldnn_Oiw16o,
+        OIw4i16o4i = mkldnn_OIw4i16o4i,
+        OIw4i4o = mkldnn_OIw4i4o,
+        Oiw4o = mkldnn_Oiw4o,
+        OIw8i16o2i = mkldnn_OIw8i16o2i,
+        OIw8i8o = mkldnn_OIw8i8o,
+        OIw8o16i2o = mkldnn_OIw8o16i2o,
+        OIw8o8i = mkldnn_OIw8o8i,
+        Owi16o = mkldnn_Owi16o,
+        Owi4o = mkldnn_Owi4o,
+        Owi8o = mkldnn_Owi8o,
+        IOhw16o16i = mkldnn_IOhw16o16i,
+        Ohwi16o = mkldnn_Ohwi16o,
+        Ohwi4o = mkldnn_Ohwi4o,
+        Ohwi8o = mkldnn_Ohwi8o,
+        OIhw16i16o = mkldnn_OIhw16i16o,
+        OIhw16o16i = mkldnn_OIhw16o16i,
+        Oihw16o = mkldnn_Oihw16o,
+        OIhw4i16o4i = mkldnn_OIhw4i16o4i,
+        OIhw4i4o = mkldnn_OIhw4i4o,
+        Oihw4o = mkldnn_Oihw4o,
+        OIhw8i16o2i = mkldnn_OIhw8i16o2i,
+        OIhw8i8o = mkldnn_OIhw8i8o,
+        OIhw8o16i2o = mkldnn_OIhw8o16i2o,
+        OIhw8o8i = mkldnn_OIhw8o8i,
+        Odhwi16o = mkldnn_Odhwi16o,
+        Odhwi4o = mkldnn_Odhwi4o,
+        Odhwi8o = mkldnn_Odhwi8o,
+        OIdhw16i16o = mkldnn_OIdhw16i16o,
+        OIdhw16o16i = mkldnn_OIdhw16o16i,
+        Oidhw16o = mkldnn_Oidhw16o,
+        OIdhw4i4o = mkldnn_OIdhw4i4o,
+        Oidhw4o = mkldnn_Oidhw4o,
+        OIdhw8i16o2i = mkldnn_OIdhw8i16o2i,
+        OIdhw8i8o = mkldnn_OIdhw8i8o,
+        OIdhw8o8i = mkldnn_OIdhw8o8i,
+        gIOw16o16i = mkldnn_gIOw16o16i,
+        gOIw16i16o = mkldnn_gOIw16i16o,
+        gOIw16o16i = mkldnn_gOIw16o16i,
+        gOiw16o = mkldnn_gOiw16o,
+        gOIw4i16o4i = mkldnn_gOIw4i16o4i,
+        gOIw4i4o = mkldnn_gOIw4i4o,
+        gOiw4o = mkldnn_gOiw4o,
+        gOIw8i16o2i = mkldnn_gOIw8i16o2i,
+        gOIw8i8o = mkldnn_gOIw8i8o,
+        gOIw8o16i2o = mkldnn_gOIw8o16i2o,
+        gOIw8o8i = mkldnn_gOIw8o8i,
+        gOwi16o = mkldnn_gOwi16o,
+        gOwi4o = mkldnn_gOwi4o,
+        gOwi8o = mkldnn_gOwi8o,
+        gIOhw16o16i = mkldnn_gIOhw16o16i,
+        gOhwi16o = mkldnn_gOhwi16o,
+        gOhwi4o = mkldnn_gOhwi4o,
+        gOhwi8o = mkldnn_gOhwi8o,
+        Goihw16g = mkldnn_Goihw16g,
+        gOIhw16i16o = mkldnn_gOIhw16i16o,
+        gOIhw16o16i = mkldnn_gOIhw16o16i,
+        gOihw16o = mkldnn_gOihw16o,
+        gOIhw2i8o4i = mkldnn_gOIhw2i8o4i,
+        gOIhw4i16o4i = mkldnn_gOIhw4i16o4i,
+        gOIhw4i4o = mkldnn_gOIhw4i4o,
+        gOIhw4o4i = mkldnn_gOIhw4o4i,
+        gOihw4o = mkldnn_gOihw4o,
+        Goihw8g = mkldnn_Goihw8g,
+        gOIhw8i16o2i = mkldnn_gOIhw8i16o2i,
+        gOIhw8i8o = mkldnn_gOIhw8i8o,
+        gOIhw8o16i2o = mkldnn_gOIhw8o16i2o,
+        gOIhw8o8i = mkldnn_gOIhw8o8i,
+        gOdhwi16o = mkldnn_gOdhwi16o,
+        gOdhwi4o = mkldnn_gOdhwi4o,
+        gOdhwi8o = mkldnn_gOdhwi8o,
+        gOIdhw16i16o = mkldnn_gOIdhw16i16o,
+        gOIdhw16o16i = mkldnn_gOIdhw16o16i,
+        gOidhw16o = mkldnn_gOidhw16o,
+        gOIdhw4i4o = mkldnn_gOIdhw4i4o,
+        gOidhw4o = mkldnn_gOidhw4o,
+        gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i,
+        gOIdhw8i8o = mkldnn_gOIdhw8i8o,
+        gOIdhw8o8i = mkldnn_gOIdhw8o8i,
+    };
+
+    /// A memory descriptor.
+    struct desc {
+        friend struct memory;
+        /// The underlying C API data structure.
+        mkldnn_memory_desc_t data;
+
+        /// Constructs a zero memory descriptor.
+        desc(): data() {}
+
+        /// Constructs a memory descriptor.
+        ///
+        /// @param adims Data dimensions
+        /// @param adata_type Data precision/type.
+        /// @param aformat Data layout format tag.
+        desc(const dims &adims, data_type adata_type,
+                format_tag aformat) {
+            validate_dims(adims);
+            error::wrap_c_api(mkldnn_memory_desc_init_by_tag(&data, (int)adims.size(),
+                        adims.size() == 0 ? nullptr : &adims[0],
+                        convert_to_c(adata_type), convert_to_c(aformat)),
+                    "could not initialize a memory descriptor");
+        }
+
+        /// Constructs a memory descriptor from a C API data structure.
+        ///
+        /// @param adata A C API #mkldnn_memory_desc_t structure.
+        desc(const mkldnn_memory_desc_t &adata): data(adata) {}
+
+        /// Constructs a sub-memory descriptor.
+        ///
+        /// @param adims Sizes of a sub-memory
+        /// @param offsets Offsets of a sub-memory
+        desc submemory_desc(const dims &adims, const dims &offsets) {
+            mkldnn_memory_desc_t sub_md;
+            error::wrap_c_api(mkldnn_memory_desc_init_submemory(&sub_md,
+                        &data, &adims[0], &offsets[0]),
+                    "could not initialize a sub-memory");
+            return desc(sub_md);
+        }
+
+        /// Returns the number of bytes required to allocate the memory
+        /// described by this descriptor, including the padding area.
+        size_t get_size() const { return mkldnn_memory_desc_get_size(&data); }
+
+        bool operator==(const desc &other) const {
+            return mkldnn_memory_desc_equal(&data, &other.data) != 0;
+        }
+
+        bool operator!=(const desc &other) const { return !operator==(other); }
+    };
+
+    /// Constructs a memory.
+    ///
+    /// @param md Memory descriptor.
+    /// @param aengine Engine.
+    /// @param ahandle Native handle.
+    memory(const desc &md, const engine &aengine, void *ahandle) {
+        mkldnn_memory_t result;
+        error::wrap_c_api(mkldnn_memory_create(&result, &md.data,
+                    aengine.get(), ahandle), "could not create a memory");
+        reset(result);
+    }
+
+    /// Constructs a memory.
+    ///
+    /// @param md Memory descriptor.
+    /// @param aengine Engine.
+    memory(const desc &md, const engine &aengine)
+        : memory(md, aengine, MKLDNN_NATIVE_HANDLE_ALLOCATE) {}
+
+    /// Returns the descriptor of the memory.
+    desc get_desc() const {
+        const mkldnn_memory_desc_t *cdesc;
+        error::wrap_c_api(mkldnn_memory_get_memory_desc(get(), &cdesc),
+                "could not get memory descriptor from a memory");
+        return desc(*cdesc);
+    }
+
+    /// Returns the engine of the memory.
+    engine get_engine() const {
+        mkldnn_engine_t engine_q;
+        error::wrap_c_api(mkldnn_memory_get_engine(get(), &engine_q),
+                "could not get engine from a memory");
+        return engine(engine_q);
+    }
+
+    /// Returns a handle of the data contained in the memory.
+    ///
+    /// On the CPU engine, this is a pointer to the allocated memory.
+    void *get_data_handle() const {
+        void *handle;
+        error::wrap_c_api(mkldnn_memory_get_data_handle(get(), &handle),
+                "could not get native handle");
+        return handle;
+    }
+
+    void set_data_handle(void *handle) const {
+        error::wrap_c_api(mkldnn_memory_set_data_handle(get(), handle),
+                "could not set native handle");
+    }
+
+    // Must go away or be private:
+    static mkldnn_data_type_t convert_to_c(data_type adata_type) {
+        return static_cast<mkldnn_data_type_t>(adata_type);
+    }
+    static mkldnn_format_tag_t convert_to_c(format_tag aformat) {
+        return static_cast<mkldnn_format_tag_t>(aformat);
+    }
+};
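+
+// Minimal usage sketch for the memory API above, assuming an engine `eng`
+// constructed earlier and a user-owned float buffer `user_ptr` of sufficient
+// size; data_type::f32 is declared earlier in this header.
+//
+//     memory::desc md({2, 16, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     memory lib_owned(md, eng);            // library allocates the buffer
+//     memory user_owned(md, eng, user_ptr); // wraps the user pointer
+//     size_t bytes = md.get_size();         // size including padding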
+
+inline bool operator==(mkldnn_data_type_t a, memory::data_type b) {
+    return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_data_type_t a, memory::data_type b) {
+    return !(a == b);
+}
+inline bool operator==(memory::data_type a, mkldnn_data_type_t b) {
+    return b == a;
+}
+inline bool operator!=(memory::data_type a, mkldnn_data_type_t b) {
+    return !(a == b);
+}
+
+inline bool operator==(mkldnn_format_tag_t a, memory::format_tag b) {
+    return a == memory::convert_to_c(b);
+}
+inline bool operator!=(mkldnn_format_tag_t a, memory::format_tag b) {
+    return !(a == b);
+}
+inline bool operator==(memory::format_tag a, mkldnn_format_tag_t b) {
+    return b == a;
+}
+inline bool operator!=(memory::format_tag a, mkldnn_format_tag_t b) {
+    return !(a == b);
+}
+
+/// @}
+
+/// @addtogroup cpp_api_reorder Reorder
+/// A primitive to copy data between memory formats.
+///
+/// @sa @ref c_api_reorder in @ref c_api
+/// @{
+
+struct reorder : public primitive {
+    struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+        primitive_desc(const engine &src_engine, const memory::desc &src_md,
+                const engine &dst_engine, const memory::desc &dst_md,
+                const primitive_attr &aattr) {
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
+                        src_engine.get(), &src_md.data,
+                        dst_engine.get(), &dst_md.data, aattr.get()),
+                    "could not create a reorder primitive descriptor");
+            reset(result);
+        }
+
+        primitive_desc(const engine &src_engine, const memory::desc &src_md,
+                const engine &dst_engine, const memory::desc &dst_md) {
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
+                        src_engine.get(), &src_md.data,
+                        dst_engine.get(), &dst_md.data, nullptr),
+                    "could not create a reorder primitive descriptor");
+            reset(result);
+        }
+
+        primitive_desc(const memory &src, const memory &dst,
+                const primitive_attr &aattr) {
+            mkldnn_primitive_desc_t result;
+            auto src_md = src.get_desc();
+            auto dst_md = dst.get_desc();
+            error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
+                        src.get_engine().get(), &src_md.data,
+                        dst.get_engine().get(), &dst_md.data, aattr.get()),
+                    "could not create a reorder primitive descriptor");
+            reset(result);
+        }
+
+        primitive_desc(const memory &src, const memory &dst) {
+            mkldnn_primitive_desc_t result;
+            auto src_md = src.get_desc();
+            auto dst_md = dst.get_desc();
+            error::wrap_c_api(mkldnn_reorder_primitive_desc_create(&result,
+                        src.get_engine().get(), &src_md.data,
+                        dst.get_engine().get(), &dst_md.data, nullptr),
+                    "could not create a reorder primitive descriptor");
+            reset(result);
+        }
+
+        memory::desc scratchpad_desc() const {
+            const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                    get(), mkldnn::convert_to_c(scratchpad_md), 0);
+            if (cdesc == nullptr)
+                return memory::desc();
+            return memory::desc(*cdesc);
+        }
+
+        engine scratchpad_engine() {
+            mkldnn_engine_t engine_q;
+            error::wrap_c_api(
+                mkldnn_primitive_desc_query(get(),
+                    mkldnn::convert_to_c(query_scratchpad_engine), 0, &engine_q),
+                "could not get scratchpad engine from reorder primitive_desc");
+
+            return engine(engine_q);
+        }
+
+        engine get_engine() { return engine::query(*this); }
+    };
+
+    reorder(const primitive_desc &pd): primitive(pd.get()) {}
+
+    reorder(const memory &src, const memory &dst):
+        primitive(primitive_desc(src, dst).get()) {}
+
+    void execute(stream astream, memory &src, memory &dst) {
+        primitive::execute(astream,
+                {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}});
+    }
+};
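+
+// Minimal usage sketch for reorder, assuming an engine `eng` and a stream `s`
+// created earlier; the primitive copies data from a plain nchw memory into a
+// blocked nChw16c memory of the same logical shape.
+//
+//     memory::desc src_md({2, 16, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     memory::desc dst_md({2, 16, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nChw16c);
+//     memory src(src_md, eng), dst(dst_md, eng);
+//     reorder r(src, dst);
+//     r.execute(s, src, dst);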
+
+/// @}
+
+/// @addtogroup cpp_api_concat Concat
+/// A primitive to concatenate data along an arbitrary dimension.
+///
+/// @sa @ref c_api_concat in @ref c_api
+/// @{
+
+struct concat : public primitive {
+    struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+        std::vector<mkldnn_memory_desc_t> cpp_to_c(
+                const std::vector<memory::desc> &srcs) {
+            std::vector<mkldnn_memory_desc_t> c_api_srcs;
+            c_api_srcs.reserve(srcs.size());
+            for (const auto &s : srcs) c_api_srcs.push_back(s.data);
+            return c_api_srcs;
+        }
+
+        primitive_desc(const memory::desc &dst, int concat_dimension,
+                const std::vector<memory::desc> &srcs, const engine &aengine) {
+            auto c_api_srcs = cpp_to_c(srcs);
+
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_concat_primitive_desc_create(
+                    &result, &dst.data, (int)c_api_srcs.size(),
+                    concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
+                "could not create a concat primitive descriptor");
+            reset(result);
+        }
+
+        primitive_desc(int concat_dimension,
+                const std::vector<memory::desc> &srcs, const engine &aengine) {
+            auto c_api_srcs = cpp_to_c(srcs);
+
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_concat_primitive_desc_create(
+                    &result, nullptr, (int)c_api_srcs.size(),
+                    concat_dimension, &c_api_srcs[0], nullptr, aengine.get()),
+                "could not create a concat primitive descriptor");
+            reset(result);
+        }
+
+        memory::desc dst_desc() const {
+            const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                    get(), mkldnn::convert_to_c(dst_md), 0);
+            error::wrap_c_api(
+                    cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
+                    "could not get a dst memory descriptor");
+            return memory::desc(*cdesc);
+        }
+
+        memory::desc scratchpad_desc() const {
+            const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                    get(), mkldnn::convert_to_c(scratchpad_md), 0);
+            if (cdesc == nullptr)
+                return memory::desc();
+            return memory::desc(*cdesc);
+        }
+
+        engine get_engine() { return engine::query(*this); }
+    };
+
+    concat(const primitive_desc &pd): primitive(pd.get()) {}
+};
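+
+// Minimal usage sketch for concat, assuming an engine `eng`; two nchw sources
+// are concatenated along the channel dimension (axis 1), and the destination
+// layout is left for the library to choose via the dst-less constructor.
+//
+//     memory::desc a({2, 8, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     memory::desc b({2, 24, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     concat::primitive_desc cpd(1, {a, b}, eng);
+//     memory dst(cpd.dst_desc(), eng);
+//     concat c(cpd);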
+
+/// @}
+
+/// @addtogroup cpp_api_sum Sum
+/// A primitive to sum data.
+///
+/// @sa @ref c_api_sum in @ref c_api
+/// @{
+
+struct sum : public primitive {
+    struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+        std::vector<mkldnn_memory_desc_t> cpp_to_c(
+                const std::vector<memory::desc> &srcs) {
+            std::vector<mkldnn_memory_desc_t> c_api_srcs;
+            c_api_srcs.reserve(srcs.size());
+            for (const auto &s : srcs) c_api_srcs.push_back(s.data);
+            return c_api_srcs;
+        }
+
+        primitive_desc(const memory::desc &dst,
+                const std::vector<float> &scales,
+                const std::vector<memory::desc> &srcs, const engine &aengine) {
+            error::wrap_c_api(scales.size() == srcs.size()
+                    ? mkldnn_success : mkldnn_invalid_arguments,
+                "number of scales not equal to number of srcs");
+
+            auto c_api_srcs = cpp_to_c(srcs);
+
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_sum_primitive_desc_create(
+                    &result, &dst.data, (int)c_api_srcs.size(),
+                    &scales[0], &c_api_srcs[0], nullptr, aengine.get()),
+                "could not create a sum primitive descriptor");
+            reset(result);
+        }
+
+        primitive_desc(const std::vector<float> &scales,
+                const std::vector<memory::desc> &srcs, const engine &aengine) {
+            error::wrap_c_api(scales.size() == srcs.size()
+                    ? mkldnn_success : mkldnn_invalid_arguments,
+                "number of scales not equal to number of srcs");
+
+            auto c_api_srcs = cpp_to_c(srcs);
+            mkldnn_primitive_desc_t result;
+            error::wrap_c_api(mkldnn_sum_primitive_desc_create(&result,
+                        nullptr, (int)c_api_srcs.size(), &scales[0],
+                        &c_api_srcs[0], nullptr, aengine.get()),
+                    "could not create a sum primitive descriptor");
+            reset(result);
+        }
+
+        memory::desc dst_desc() const {
+            const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                    get(), mkldnn::convert_to_c(dst_md), 0);
+            error::wrap_c_api(
+                    cdesc == nullptr ? mkldnn_runtime_error : mkldnn_success,
+                    "could not get a dst memory descriptor");
+            return memory::desc(*cdesc);
+        }
+
+        memory::desc scratchpad_desc() const {
+            const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                    get(), mkldnn::convert_to_c(scratchpad_md), 0);
+            if (cdesc == nullptr)
+                return memory::desc();
+            return memory::desc(*cdesc);
+        }
+
+        engine get_engine() { return engine::query(*this); }
+    };
+
+    sum(const primitive_desc &pd): primitive(pd.get()) {}
+};
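+
+// Minimal usage sketch for sum, assuming an engine `eng`; each source is
+// scaled before accumulation and the destination descriptor is deduced from
+// the sources by the dst-less constructor.
+//
+//     memory::desc md({2, 16, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     std::vector<float> scales = {1.0f, 0.5f};
+//     sum::primitive_desc spd(scales, {md, md}, eng);
+//     memory dst(spd.dst_desc(), eng);
+//     sum s_prim(spd);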
+
+/// @}
+
+/// @}
+
+/// @addtogroup cpp_api_primitives Primitives
+/// @{
+
+/// @addtogroup cpp_api_primitive_descriptors Primitive descriptors
+/// @{
+
+/// A base class for all primitive descriptors.
+struct primitive_desc : public handle<mkldnn_primitive_desc_t> {
+    primitive_desc(const_mkldnn_op_desc_t desc, const primitive_attr *attr,
+            const engine &e, const_mkldnn_primitive_desc_t hint_fwd_pd) {
+        mkldnn_primitive_desc_iterator_t iterator = nullptr;
+        mkldnn_status_t status = mkldnn_primitive_desc_iterator_create(
+                &iterator, desc, attr ? attr->get() : nullptr, e.get(),
+                hint_fwd_pd);
+        error::wrap_c_api(status,
+                "could not create a primitive descriptor iterator");
+        pd_iterator.reset(iterator);
+        fetch_impl();
+    }
+
+    engine get_engine() { return engine::query(*this); }
+
+    primitive_attr get_primitive_attr() const {
+        const_mkldnn_primitive_attr_t const_cattr;
+        error::wrap_c_api(mkldnn_primitive_desc_get_attr(get(), &const_cattr),
+                "could not get attributes");
+        mkldnn_primitive_attr_t cattr;
+        error::wrap_c_api(mkldnn_primitive_attr_clone(&cattr, const_cattr),
+                "could not clone attributes");
+
+        primitive_attr attr;
+        attr.reset(cattr);
+        return attr;
+    }
+
+    /// Returns the implementation name.
+    const char *impl_info_str() const {
+        const char *res;
+        error::wrap_c_api(mkldnn_primitive_desc_query(get(),
+                    mkldnn_query_impl_info_str, 0, &res),
+                "could not query implementation info string");
+        return res;
+    }
+
+    /// Queries the memory::dim value (same as int64_t)
+    memory::dim query_s64(query q) const {
+        memory::dim res;
+        mkldnn_status_t status = mkldnn_primitive_desc_query(get(),
+                mkldnn::convert_to_c(q), 0, &res);
+        return status == mkldnn_success ? res : 0;
+    }
+
+    /// Advances to the next implementation for the given op descriptor.
+    ///
+    /// Returns:
+    /// - @c true on success
+    /// - @c false if the last implementation has already been reached, in
+    ///   which case the primitive descriptor itself is kept unchanged
+    bool next_impl() {
+        mkldnn_status_t status = mkldnn_primitive_desc_iterator_next(
+                pd_iterator.get());
+        if (status == mkldnn_iterator_ends) return false;
+        error::wrap_c_api(status, "primitive descriptor iterator next failed");
+
+        fetch_impl();
+        return true;
+    }
+
+    /// Queries and returns the requested memory descriptor.
+    memory::desc query_md(query what, int idx = 0) const {
+        std::vector<query> valid_q{src_md, diff_src_md, weights_md,
+            diff_weights_md, dst_md, diff_dst_md, workspace_md, scratchpad_md};
+        if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
+                    [=](query q) { return what == q; }))
+            throw error(mkldnn_invalid_arguments, "invalid memory query");
+
+        const mkldnn_memory_desc_t *cdesc = mkldnn_primitive_desc_query_md(
+                get(), mkldnn::convert_to_c(what), idx);
+        if (cdesc == nullptr) return memory::desc();
+
+        return memory::desc(*cdesc);
+    }
+
+    // register specialized queries, e.g. src_desc()
+#   define REG_QUERY_MD(name, what, idx) \
+    memory::desc name ## _desc() const { return query_md(what ## _md, idx); }
+
+  private:
+    handle<mkldnn_primitive_desc_iterator_t> pd_iterator;
+    void fetch_impl() {
+        mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(
+                pd_iterator.get());
+        error::wrap_c_api(pd != nullptr ? mkldnn_success : mkldnn_runtime_error,
+                "could not fetch a primitive descriptor from the iterator");
+        reset(pd);
+    }
+};
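+
+// Minimal usage sketch for the implementation iterator, given any concrete
+// primitive descriptor `pd` (for example one of the convolution descriptors
+// below); next_impl() returns false once the last implementation has been
+// reached, leaving `pd` on the last fetched implementation.
+//
+//     do {
+//         printf("impl: %s\n", pd.impl_info_str());
+//     } while (pd.next_impl());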
+
+/// @}
+
+/// @addtogroup cpp_api_convolution Convolution
+/// A primitive to compute convolution using different algorithms.
+///
+/// @sa @ref c_api_convolution in @ref c_api
+/// @{
+
+struct convolution_forward: public primitive {
+    struct desc {
+        mkldnn_convolution_desc_t data;
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, &bias_desc.data,
+                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_convolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, nullptr,
+                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(
+                mkldnn_dilated_convolution_forward_desc_init(&data,
+                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, &bias_desc.data,
+                        &dst_desc.data, &strides[0], &dilates[0],
+                        &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated convolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(
+                mkldnn_dilated_convolution_forward_desc_init(&data,
+                    mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, nullptr,
+                        &dst_desc.data, &strides[0], &dilates[0],
+                        &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated convolution forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(bias, weights, 1);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    convolution_forward(const primitive_desc &pd): primitive(pd) {}
+};
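+
+// Minimal usage sketch for a forward convolution without bias, assuming an
+// engine `eng`, a stream `s`, the usual enum values declared earlier in this
+// header (prop_kind::forward_inference, algorithm::convolution_direct,
+// padding_kind::zero), and the MKLDNN_ARG_* indices from the C API.
+//
+//     memory::desc src({1, 16, 16, 16}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     memory::desc wei({32, 16, 3, 3}, memory::data_type::f32,
+//             memory::format_tag::oihw);
+//     memory::desc dst({1, 32, 16, 16}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     convolution_forward::desc cd(prop_kind::forward_inference,
+//             algorithm::convolution_direct, src, wei, dst,
+//             {1, 1}, {1, 1}, {1, 1}, padding_kind::zero);
+//     convolution_forward::primitive_desc cpd(cd, eng);
+//     memory src_m(cpd.src_desc(), eng), wei_m(cpd.weights_desc(), eng),
+//            dst_m(cpd.dst_desc(), eng);
+//     convolution_forward(cpd).execute(s,
+//             {{MKLDNN_ARG_SRC, src_m}, {MKLDNN_ARG_WEIGHTS, wei_m},
+//              {MKLDNN_ARG_DST, dst_m}});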
+
+struct convolution_backward_data : public primitive {
+    struct desc {
+        mkldnn_convolution_desc_t data;
+        desc(algorithm aalgorithm,
+                const memory::desc &diff_src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_convolution_backward_data_desc_init(
+                        &data, convert_to_c(aalgorithm), &diff_src_desc.data,
+                        &weights_desc.data, &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward data descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &diff_src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(
+                mkldnn_dilated_convolution_backward_data_desc_init(
+                    &data, convert_to_c(aalgorithm), &diff_src_desc.data,
+                    &weights_desc.data, &diff_dst_desc.data,
+                    &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                    mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward data descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const convolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const convolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    convolution_backward_data(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct convolution_backward_weights : public primitive {
+    struct desc {
+        mkldnn_convolution_desc_t data;
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, &diff_bias_desc.data,
+                        &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_convolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, &diff_bias_desc.data,
+                        &diff_dst_desc.data,
+                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_convolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
+                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a convolution backward weights descriptor");
+        }
+
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const convolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const convolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(diff_weights, diff_weights, 0);
+        REG_QUERY_MD(diff_bias, diff_weights, 1);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    convolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
+
+/// @addtogroup cpp_api_deconvolution Deconvolution
+/// A primitive to compute deconvolution using different algorithms.
+///
+/// @sa @ref c_api_deconvolution in @ref c_api
+/// @{
+
+struct deconvolution_forward: public primitive {
+    struct desc {
+        mkldnn_deconvolution_desc_t data;
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, &bias_desc.data,
+                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a deconvolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_deconvolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, nullptr,
+                        &dst_desc.data, &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a deconvolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, &bias_desc.data,
+                        &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
+                        &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated deconvolution forward descriptor");
+        }
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_deconvolution_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                        &src_desc.data, &weights_desc.data, nullptr,
+                        &dst_desc.data, &strides[0], &dilates[0], &padding_l[0],
+                        &padding_r[0], mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated deconvolution forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(bias, weights, 1);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    deconvolution_forward(const primitive_desc &pd): primitive(pd) {}
+};
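+
+// Deconvolution mirrors the convolution API above; a minimal sketch, assuming
+// algorithm::deconvolution_direct is declared earlier in this header and that
+// src, wei and dst are memory::desc objects set up as in the convolution
+// example (with shapes appropriate for a deconvolution).
+//
+//     deconvolution_forward::desc dd(prop_kind::forward_inference,
+//             algorithm::deconvolution_direct, src, wei, dst,
+//             {1, 1}, {1, 1}, {1, 1}, padding_kind::zero);
+//     deconvolution_forward::primitive_desc dpd(dd, eng);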
+
+struct deconvolution_backward_data : public primitive {
+    struct desc {
+        mkldnn_deconvolution_desc_t data;
+        desc(algorithm aalgorithm,
+                const memory::desc &diff_src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_deconvolution_backward_data_desc_init(
+                        &data, convert_to_c(aalgorithm), &diff_src_desc.data,
+                        &weights_desc.data, &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a deconvolution backward data descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &diff_src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_deconvolution_backward_data_desc_init(
+                        &data, convert_to_c(aalgorithm), &diff_src_desc.data,
+                        &weights_desc.data, &diff_dst_desc.data,
+                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated deconvolution backward data descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const deconvolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const deconvolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    deconvolution_backward_data(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct deconvolution_backward_weights : public primitive {
+    struct desc {
+        mkldnn_deconvolution_desc_t data;
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, &diff_bias_desc.data,
+                        &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a deconvolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_deconvolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
+                        &strides[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a deconvolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, &diff_bias_desc.data,
+                        &diff_dst_desc.data,
+                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated  deconvolution backward weights descriptor");
+        }
+        desc(algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims strides,
+                const memory::dims dilates,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(dilates);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_dilated_deconvolution_backward_weights_desc_init(
+                        &data, convert_to_c(aalgorithm), &src_desc.data,
+                        &diff_weights_desc.data, nullptr, &diff_dst_desc.data,
+                        &strides[0], &dilates[0], &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not create a dilated deconvolution backward weights descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const deconvolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const deconvolution_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(diff_weights, diff_weights, 0);
+        REG_QUERY_MD(diff_bias, diff_weights, 1);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    deconvolution_backward_weights(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
+
+/// @addtogroup cpp_api_lrn LRN
+/// A primitive to perform local response normalization (LRN) across or within
+/// channels.
+///
+/// @sa @ref c_api_lrn in @ref c_api
+/// @{
+
+struct lrn_forward : public primitive {
+    struct desc {
+        mkldnn_lrn_desc_t data;
+
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc, memory::dim local_size,
+                float alpha, float beta, float k = 1.f) {
+            error::wrap_c_api(mkldnn_lrn_forward_desc_init(&data,
+                mkldnn::convert_to_c(aprop_kind), convert_to_c(aalgorithm),
+                &src_desc.data, local_size, alpha, beta, k),
+                "could not create a lrn forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    lrn_forward(const primitive_desc &pd): primitive(pd) {}
+};
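+
+// Minimal usage sketch for forward LRN, assuming an engine `eng` and the enum
+// values prop_kind::forward_inference and algorithm::lrn_across_channels
+// declared earlier in this header.
+//
+//     memory::desc src({2, 16, 7, 7}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     lrn_forward::desc ld(prop_kind::forward_inference,
+//             algorithm::lrn_across_channels, src,
+//             /*local_size=*/5, /*alpha=*/1e-4f, /*beta=*/0.75f);
+//     lrn_forward::primitive_desc lpd(ld, eng);
+//     lrn_forward lrn(lpd);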
+
+struct lrn_backward : public primitive {
+    struct desc {
+        mkldnn_lrn_desc_t data;
+
+        desc(algorithm aalgorithm, const memory::desc &data_desc,
+                const memory::desc &diff_data_desc, memory::dim local_size,
+                float alpha, float beta, float k = 1.f) {
+            error::wrap_c_api(mkldnn_lrn_backward_desc_init(&data,
+                convert_to_c(aalgorithm), &diff_data_desc.data,
+                &data_desc.data, local_size, alpha, beta, k),
+                "could not create a lrn backward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const lrn_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const lrn_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    lrn_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
+
+/// @addtogroup cpp_api_pooling Pooling
+/// A primitive to perform max or average pooling.
+///
+/// @sa @ref c_api_pooling in @ref c_api
+/// @{
+
+struct pooling_forward : public primitive {
+    struct desc {
+        mkldnn_pooling_desc_t data;
+        desc(prop_kind aprop_kind, algorithm aalgorithm,
+                const memory::desc &src_desc,
+                const memory::desc &dst_desc,
+                const memory::dims strides,
+                const memory::dims kernel,
+                const memory::dims padding_l,
+                const memory::dims padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(kernel);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_pooling_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind),
+                        convert_to_c(aalgorithm),
+                        &src_desc.data, &dst_desc.data,
+                        &strides[0], &kernel[0],
+                        &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not init a forward pooling descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    pooling_forward(const primitive_desc &pd): primitive(pd) {}
+};
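+
+// Minimal usage sketch for 2x2 max pooling with stride 2, assuming an engine
+// `eng` and the enum values prop_kind::forward_inference,
+// algorithm::pooling_max and padding_kind::zero declared earlier in this
+// header.
+//
+//     memory::desc src({2, 16, 8, 8}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     memory::desc dst({2, 16, 4, 4}, memory::data_type::f32,
+//             memory::format_tag::nchw);
+//     pooling_forward::desc pd_desc(prop_kind::forward_inference,
+//             algorithm::pooling_max, src, dst,
+//             {2, 2}, {2, 2}, {0, 0}, {0, 0}, padding_kind::zero);
+//     pooling_forward::primitive_desc ppd(pd_desc, eng);
+//     pooling_forward pool(ppd);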
+
+struct pooling_backward : public primitive {
+    struct desc {
+        mkldnn_pooling_desc_t data;
+        desc(algorithm aalgorithm,
+                const memory::desc &diff_src_desc,
+                const memory::desc &diff_dst_desc,
+                const memory::dims &strides,
+                const memory::dims &kernel,
+                const memory::dims &padding_l,
+                const memory::dims &padding_r,
+                const padding_kind apadding_kind) {
+            memory::validate_dims(strides);
+            memory::validate_dims(kernel);
+            memory::validate_dims(padding_l);
+            memory::validate_dims(padding_r);
+            error::wrap_c_api(mkldnn_pooling_backward_desc_init(&data,
+                        convert_to_c(aalgorithm),
+                        &diff_src_desc.data, &diff_dst_desc.data,
+                        &strides[0], &kernel[0],
+                        &padding_l[0], &padding_r[0],
+                        mkldnn::convert_to_c(apadding_kind)),
+                    "could not init a backward pooling descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const pooling_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const pooling_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    pooling_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
+
+/// @addtogroup cpp_api_eltwise Eltwise
+/// A primitive to compute element-wise operations such as the parametric
+/// rectified linear unit (ReLU).
+///
+/// @sa @ref c_api_eltwise in @ref c_api
+/// @{
+
+struct eltwise_forward : public primitive {
+    struct desc {
+        mkldnn_eltwise_desc_t data;
+        template <typename T>
+        desc(prop_kind aprop_kind, algorithm alg_kind,
+                const memory::desc &src_desc, T alpha = 0, T beta = 0) {
+            error::wrap_c_api(mkldnn_eltwise_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind),
+                        mkldnn::convert_to_c(alg_kind), &src_desc.data,
+                        static_cast<float>(alpha), static_cast<float>(beta)),
+                    "could not create a eltwise forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    eltwise_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct eltwise_backward : public primitive {
+    struct desc {
+        mkldnn_eltwise_desc_t data;
+
+        template <typename T>
+        desc(algorithm alg_kind, const memory::desc &diff_data_desc,
+                const memory::desc &data_desc, T alpha = 0, T beta = 0) {
+            error::wrap_c_api(mkldnn_eltwise_backward_desc_init(&data,
+                        mkldnn::convert_to_c(alg_kind), &diff_data_desc.data,
+                        &data_desc.data, static_cast<float>(alpha),
+                        static_cast<float>(beta)),
+                    "could not create a eltwise backward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const eltwise_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const eltwise_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    eltwise_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
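+
+// A minimal sketch of running the eltwise primitive above as a ReLU, assuming an
+// engine `eng`, a stream `s`, memory objects `src_mem`/`dst_mem` with descriptor
+// `src_md`, the MKLDNN_ARG_* indices from mkldnn_types.h, and C++ `prop_kind` /
+// `algorithm` values mirroring the C enums:
+//
+//     auto relu_d  = eltwise_forward::desc(prop_kind::forward_inference,
+//             algorithm::eltwise_relu, src_md, 0.f, 0.f);
+//     auto relu_pd = eltwise_forward::primitive_desc(relu_d, eng);
+//     eltwise_forward(relu_pd).execute(s,
+//             {{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_DST, dst_mem}});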
+
+/// @addtogroup cpp_api_softmax Softmax
+/// A primitive to perform softmax.
+///
+/// @sa @ref c_api_softmax in @ref c_api
+/// @{
+
+struct softmax_forward : public primitive {
+    struct desc {
+        mkldnn_softmax_desc_t data;
+        desc(prop_kind aprop_kind, const memory::desc &data_desc,
+             int softmax_axis) {
+            error::wrap_c_api(mkldnn_softmax_forward_desc_init(&data,
+                    mkldnn::convert_to_c(aprop_kind), &data_desc.data,
+                    softmax_axis),
+                "could not create a softmax forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    softmax_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct softmax_backward : public primitive {
+    struct desc {
+        mkldnn_softmax_desc_t data;
+        desc(const memory::desc &diff_desc, const memory::desc &data_desc,
+                int softmax_axis) {
+            error::wrap_c_api(mkldnn_softmax_backward_desc_init(&data,
+                        &diff_desc.data, &data_desc.data, softmax_axis),
+                    "could not init a backward softmax descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const softmax_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const softmax_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    softmax_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
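+
+// A minimal sketch for the softmax primitive above, taken over axis 1 (the
+// channel dimension of an {N, C} tensor). The engine `eng`, stream `s`, memory
+// objects `src_mem`/`dst_mem`, descriptor `data_md`, and the MKLDNN_ARG_*
+// indices from mkldnn_types.h are assumed to exist:
+//
+//     auto sm_d  = softmax_forward::desc(prop_kind::forward_inference, data_md, 1);
+//     auto sm_pd = softmax_forward::primitive_desc(sm_d, eng);
+//     softmax_forward(sm_pd).execute(s,
+//             {{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_DST, dst_mem}});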
+
+/// @addtogroup cpp_api_batch_norm Batch normalization
+/// A primitive to perform batch normalization.
+///
+/// @sa @ref c_api_batch_normalization in @ref c_api
+/// @{
+
+struct batch_normalization_forward : public primitive {
+    struct desc {
+        mkldnn_batch_normalization_desc_t data;
+        template <typename T>
+        desc(prop_kind aprop_kind, const memory::desc &src_desc, T epsilon,
+                unsigned flags) {
+            error::wrap_c_api(
+                    mkldnn_batch_normalization_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), &src_desc.data,
+                        static_cast<float>(epsilon), flags),
+                "could not create a batch normalization forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+
+        memory::desc mean_desc() const { return stat_desc(mean); }
+        memory::desc variance_desc() const { return stat_desc(var); }
+
+    private:
+        enum { mean = 1, var = 2, };
+        memory::desc stat_desc(int kind) const {
+            mkldnn_batch_normalization_desc_t *p;
+            error::wrap_c_api(mkldnn_primitive_desc_query(
+                    get(), mkldnn::convert_to_c(batch_normalization_d), 0, &p),
+                    "could not get a batch-normalization descriptor");
+            return query_md(p->flags & use_global_stats ? src_md : dst_md, kind);
+        }
+    };
+
+    batch_normalization_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct batch_normalization_backward : public primitive {
+    struct desc {
+        mkldnn_batch_normalization_desc_t data;
+        template <typename T>
+        desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
+                const memory::desc &data_desc, T epsilon, unsigned flags) {
+            error::wrap_c_api(
+                    mkldnn_batch_normalization_backward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind),
+                        &diff_data_desc.data, &data_desc.data,
+                        static_cast<float>(epsilon), flags),
+                "could not create a batch normalization backward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const batch_normalization_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const batch_normalization_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(mean, src, 1);
+        REG_QUERY_MD(variance, src, 2);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(workspace, workspace, 0);
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_weights, diff_weights, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    batch_normalization_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
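+
+// A minimal sketch for the batch normalization primitive above, set up for
+// inference with user-provided statistics and scale/shift. The flag values are
+// the C enums from mkldnn_types.h; `eng` and the source descriptor `src_md` are
+// assumed to exist:
+//
+//     unsigned flags = mkldnn_use_global_stats | mkldnn_use_scaleshift;
+//     auto bn_d  = batch_normalization_forward::desc(
+//             prop_kind::forward_inference, src_md, 1e-5f, flags);
+//     auto bn_pd = batch_normalization_forward::primitive_desc(bn_d, eng);
+//     memory::desc mean_md = bn_pd.mean_desc();     // statistics layouts queried
+//     memory::desc var_md  = bn_pd.variance_desc(); // via the helpers above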
+
+/// @addtogroup cpp_api_inner_product Inner Product
+/// A primitive to compute an inner product.
+///
+/// @sa @ref c_api_inner_product in @ref c_api
+/// @{
+
+struct inner_product_forward: public primitive {
+    struct desc {
+        mkldnn_inner_product_desc_t data;
+        desc(prop_kind aprop_kind, const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_desc) {
+            error::wrap_c_api(
+                    mkldnn_inner_product_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), &src_desc.data,
+                        &weights_desc.data, &bias_desc.data, &dst_desc.data),
+                    "could not create a inner product forward descriptor");
+        }
+
+        desc(prop_kind aprop_kind, const memory::desc &src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &dst_desc) {
+            error::wrap_c_api(
+                    mkldnn_inner_product_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), &src_desc.data,
+                        &weights_desc.data, nullptr, &dst_desc.data),
+                    "could not create a inner product forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(bias, weights, 1);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    inner_product_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct inner_product_backward_data: public primitive {
+    struct desc {
+        mkldnn_inner_product_desc_t data;
+        desc(const memory::desc &diff_src_desc,
+                const memory::desc &weights_desc,
+                const memory::desc &diff_dst_desc) {
+            error::wrap_c_api(
+                    mkldnn_inner_product_backward_data_desc_init(&data,
+                        &diff_src_desc.data, &weights_desc.data,
+                        &diff_dst_desc.data),
+                "could not create a inner product backward data descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const inner_product_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const inner_product_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(weights, weights, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    inner_product_backward_data(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct inner_product_backward_weights: public primitive {
+    struct desc {
+        mkldnn_inner_product_desc_t data;
+        desc(const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_desc) {
+            error::wrap_c_api(
+                    mkldnn_inner_product_backward_weights_desc_init(
+                        &data, &src_desc.data, &diff_weights_desc.data,
+                        &diff_bias_desc.data, &diff_dst_desc.data),
+                "could not create a inner product backward weights descriptor");
+        }
+        desc(const memory::desc &src_desc,
+                const memory::desc &diff_weights_desc,
+                const memory::desc &diff_dst_desc) {
+            error::wrap_c_api(
+                    mkldnn_inner_product_backward_weights_desc_init(
+                        &data, &src_desc.data, &diff_weights_desc.data,
+                        nullptr, &diff_dst_desc.data),
+                "could not create a inner product backward weights descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const inner_product_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const inner_product_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(diff_weights, diff_weights, 0);
+        REG_QUERY_MD(diff_bias, diff_weights, 1);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    inner_product_backward_weights(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
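+
+// A minimal sketch for the inner product (fully connected) primitive above,
+// assuming an engine `eng`, a stream `s`, memory descriptors `src_md`,
+// `weights_md`, `bias_md`, `dst_md`, matching memory objects, and the
+// MKLDNN_ARG_* indices from mkldnn_types.h:
+//
+//     auto ip_d  = inner_product_forward::desc(prop_kind::forward_inference,
+//             src_md, weights_md, bias_md, dst_md);
+//     auto ip_pd = inner_product_forward::primitive_desc(ip_d, eng);
+//     inner_product_forward(ip_pd).execute(s,
+//             {{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_WEIGHTS, weights_mem},
+//              {MKLDNN_ARG_BIAS, bias_mem}, {MKLDNN_ARG_DST, dst_mem}});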
+
+/// @addtogroup cpp_api_rnn RNN
+/// A primitive to compute a common recurrent layer.
+///
+/// @sa @ref c_api_rnn in @ref c_api
+/// @{
+
+struct rnn_cell {
+    struct desc {
+        mkldnn_rnn_cell_desc_t c_rnn_cell_;
+
+        desc(algorithm kind, algorithm activation_f) {
+            error::wrap_c_api(mkldnn_rnn_cell_desc_init(&c_rnn_cell_,
+                        mkldnn::convert_to_c(kind),
+                        mkldnn::convert_to_c(activation_f), 0U, 0, 0),
+                    "could not init an rnn cell descriptor");
+        }
+        desc(algorithm kind): desc(kind, algorithm::algorithm_undef) {}
+
+        operator const mkldnn_rnn_cell_desc_t*() const { return &c_rnn_cell_; }
+
+        algorithm get_cell_kind() const
+        { return algorithm(c_rnn_cell_.cell_kind); }
+        algorithm get_activation() const
+        { return algorithm(c_rnn_cell_.activation_kind); }
+
+        float get_alpha() const { return c_rnn_cell_.alpha; }
+        void set_alpha(float alpha) {
+            c_rnn_cell_.flags |= mkldnn_rnn_cell_with_relu;
+            c_rnn_cell_.alpha = alpha;
+        }
+
+        float get_clipping() const { return c_rnn_cell_.clipping; }
+        void set_clipping(float clipping) {
+            c_rnn_cell_.flags |= mkldnn_rnn_cell_with_clipping;
+            c_rnn_cell_.clipping = clipping;
+        }
+
+        int get_gates_count() const {
+            return mkldnn_rnn_cell_get_gates_count(&c_rnn_cell_);
+        }
+        int get_state_count() const {
+            return mkldnn_rnn_cell_get_states_count(&c_rnn_cell_);
+        }
+    };
+};
+
+struct rnn_forward : public primitive {
+    struct desc {
+        mkldnn_rnn_desc_t data;
+        desc(prop_kind aprop_kind, rnn_cell::desc cell,
+                const rnn_direction direction,
+                const memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc
+            ) {
+            error::wrap_c_api(mkldnn_rnn_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), cell,
+                        mkldnn::convert_to_c(direction),
+                        &src_layer_desc.data, &src_iter_desc.data,
+                        &weights_layer_desc.data, &weights_iter_desc.data,
+                        &bias_desc.data,
+                        &dst_layer_desc.data, &dst_iter_desc.data),
+                    "could not create an RNN forward descriptor");
+        }
+
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, nullptr) {}
+
+        REG_QUERY_MD(src_layer, src, 0);
+        REG_QUERY_MD(src_iter, src, 1);
+        REG_QUERY_MD(weights_layer, weights, 0);
+        REG_QUERY_MD(weights_iter, weights, 1);
+        REG_QUERY_MD(bias, weights, 2);
+        REG_QUERY_MD(dst_layer, dst, 0);
+        REG_QUERY_MD(dst_iter, dst, 1);
+        REG_QUERY_MD(workspace, workspace, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    rnn_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct rnn_backward : public primitive {
+    struct desc {
+        mkldnn_rnn_desc_t data;
+        desc(prop_kind aprop_kind, rnn_cell::desc cell,
+                const rnn_direction direction,
+                const memory::desc &src_layer_desc,
+                const memory::desc &src_iter_desc,
+                const memory::desc &weights_layer_desc,
+                const memory::desc &weights_iter_desc,
+                const memory::desc &bias_desc,
+                const memory::desc &dst_layer_desc,
+                const memory::desc &dst_iter_desc,
+                const memory::desc &diff_src_layer_desc,
+                const memory::desc &diff_src_iter_desc,
+                const memory::desc &diff_weights_layer_desc,
+                const memory::desc &diff_weights_iter_desc,
+                const memory::desc &diff_bias_desc,
+                const memory::desc &diff_dst_layer_desc,
+                const memory::desc &diff_dst_iter_desc) {
+            error::wrap_c_api(mkldnn_rnn_backward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), cell,
+                        mkldnn::convert_to_c(direction),
+                        &src_layer_desc.data, &src_iter_desc.data,
+                        &weights_layer_desc.data, &weights_iter_desc.data,
+                        &bias_desc.data,
+                        &dst_layer_desc.data, &dst_iter_desc.data,
+                        &diff_src_layer_desc.data, &diff_src_iter_desc.data,
+                        &diff_weights_layer_desc.data,
+                        &diff_weights_iter_desc.data, &diff_bias_desc.data,
+                        &diff_dst_layer_desc.data, &diff_dst_iter_desc.data),
+                    "could not create an RNN backward descriptor");
+        }
+
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const rnn_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        primitive_desc(const desc &desc, const primitive_attr &attr, const engine &e,
+                const rnn_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, &attr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(src_layer, src, 0);
+        REG_QUERY_MD(src_iter, src, 1);
+        REG_QUERY_MD(weights_layer, weights, 0);
+        REG_QUERY_MD(weights_iter, weights, 1);
+        REG_QUERY_MD(bias, weights, 2);
+        REG_QUERY_MD(dst_layer, dst, 0);
+        REG_QUERY_MD(dst_iter, dst, 1);
+        REG_QUERY_MD(workspace, workspace, 0);
+
+        REG_QUERY_MD(diff_src_layer, diff_src, 0);
+        REG_QUERY_MD(diff_src_iter, diff_src, 1);
+        REG_QUERY_MD(diff_weights_layer, diff_weights, 0);
+        REG_QUERY_MD(diff_weights_iter, diff_weights, 1);
+        REG_QUERY_MD(diff_bias, diff_weights, 2);
+        REG_QUERY_MD(diff_dst_layer, diff_dst, 0);
+        REG_QUERY_MD(diff_dst_iter, diff_dst, 1);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    // With last iteration (with and without input src_iter)
+    rnn_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
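+
+// A minimal sketch of setting up the RNN cell and forward descriptors above. It
+// assumes the C++ `algorithm` / `rnn_direction` values mirror the C enums
+// (`mkldnn_vanilla_rnn`, `mkldnn_eltwise_tanh`, unidirectional left-to-right)
+// and that the engine `eng` and the layer/iter/weights/bias memory descriptors
+// already exist:
+//
+//     auto cell  = rnn_cell::desc(algorithm::vanilla_rnn, algorithm::eltwise_tanh);
+//     auto rnn_d = rnn_forward::desc(prop_kind::forward_inference, cell,
+//             rnn_direction::unidirectional_left2right,
+//             src_layer_md, src_iter_md, weights_layer_md, weights_iter_md,
+//             bias_md, dst_layer_md, dst_iter_md);
+//     auto rnn_pd = rnn_forward::primitive_desc(rnn_d, eng);
+//     auto rnn    = rnn_forward(rnn_pd);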
+
+/// @addtogroup cpp_api_shuffle Shuffle
+/// A primitive to shuffle data along a given axis.
+///
+/// @sa @ref c_api_shuffle in @ref c_api
+/// @{
+
+struct shuffle_forward : public primitive {
+    struct desc {
+        mkldnn_shuffle_desc_t data;
+        desc(prop_kind aprop_kind, const memory::desc &data_desc,
+                int axis, int group_size) {
+            error::wrap_c_api(mkldnn_shuffle_forward_desc_init(&data,
+                        mkldnn::convert_to_c(aprop_kind), &data_desc.data,
+                        axis, group_size),
+                    "could not create a shuffle forward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, nullptr) {}
+
+        REG_QUERY_MD(src, src, 0);
+        REG_QUERY_MD(dst, dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    shuffle_forward(const primitive_desc &pd): primitive(pd) {}
+};
+
+struct shuffle_backward : public primitive {
+    struct desc {
+        mkldnn_shuffle_desc_t data;
+        desc(const memory::desc &diff_data_desc, int axis, int group_size) {
+            error::wrap_c_api(mkldnn_shuffle_backward_desc_init(&data,
+                        &diff_data_desc.data, axis, group_size),
+                    "could not create a shuffle backward descriptor");
+        }
+    };
+
+    struct primitive_desc : public mkldnn::primitive_desc {
+        primitive_desc(const desc &desc, const engine &e,
+                const shuffle_forward::primitive_desc &hint_fwd_pd)
+            : mkldnn::primitive_desc(&desc.data, nullptr, e, hint_fwd_pd.get()) {}
+
+        REG_QUERY_MD(diff_src, diff_src, 0);
+        REG_QUERY_MD(diff_dst, diff_dst, 0);
+        REG_QUERY_MD(scratchpad, scratchpad, 0);
+    };
+
+    shuffle_backward(const primitive_desc &pd): primitive(pd) {}
+};
+
+/// @}
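+
+// A minimal sketch for the shuffle primitive above: shuffling the channel axis
+// (axis 1) of a tensor described by `data_md` in groups of 4. The engine `eng`,
+// stream `s`, memory objects, and the MKLDNN_ARG_* indices from mkldnn_types.h
+// are assumed to exist:
+//
+//     auto shuf_d  = shuffle_forward::desc(prop_kind::forward_inference,
+//             data_md, /*axis=*/1, /*group_size=*/4);
+//     auto shuf_pd = shuffle_forward::primitive_desc(shuf_d, eng);
+//     shuffle_forward(shuf_pd).execute(s,
+//             {{MKLDNN_ARG_SRC, src_mem}, {MKLDNN_ARG_DST, dst_mem}});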
+
+/// @} Primitives
+
+/// @} C++ API
+
+#undef REG_QUERY_MD
+
+// implementation section
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+
+inline primitive::primitive(const_mkldnn_primitive_desc_t c_pd) {
+    mkldnn_primitive_t result;
+    error::wrap_c_api(mkldnn_primitive_create(&result, c_pd),
+            "could not create a primitive");
+    reset(result);
+}
+
+inline primitive::primitive(const primitive_desc &pd): primitive(pd.get()) {}
+
+inline void primitive::execute(stream &astream,
+        const std::unordered_map<int, memory> &args) const {
+    std::vector<mkldnn_exec_arg_t> c_args;
+    c_args.reserve(args.size());
+    for (const auto &a: args)
+        c_args.push_back({a.first, a.second.get()});
+
+    error::wrap_c_api(mkldnn_primitive_execute(get(), astream.get(),
+                (int)c_args.size(), c_args.data()),
+            "primitive execution fail");
+}
+#endif // DOXYGEN_SHOULD_SKIP_THIS
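+
+// The execute() wrapper above forwards an {argument index -> memory} map to the
+// C API as mkldnn_exec_arg_t entries. A sketch of driving any primitive `prim`
+// through it, assuming a stream `s`, memory objects, and the MKLDNN_ARG_*
+// indices defined in mkldnn_types.h:
+//
+//     std::unordered_map<int, memory> args = {
+//         {MKLDNN_ARG_SRC, src_mem},
+//         {MKLDNN_ARG_DST, dst_mem},
+//     };
+//     prim.execute(s, args);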
+
+} // namespace mkldnn
+
+#endif

+ 98 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn_debug.h

@@ -0,0 +1,98 @@
+/*******************************************************************************
+* Copyright 2018-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/* DO NOT EDIT, AUTO-GENERATED */
+
+#ifndef MKLDNN_DEBUG_H
+#define MKLDNN_DEBUG_H
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+
+/* All symbols shall be internal unless marked as MKLDNN_API */
+#if defined _WIN32 || defined __CYGWIN__
+#   define MKLDNN_HELPER_DLL_IMPORT __declspec(dllimport)
+#   define MKLDNN_HELPER_DLL_EXPORT __declspec(dllexport)
+#else
+#   if __GNUC__ >= 4
+#       define MKLDNN_HELPER_DLL_IMPORT __attribute__ ((visibility ("default")))
+#       define MKLDNN_HELPER_DLL_EXPORT __attribute__ ((visibility ("default")))
+#   else
+#       define MKLDNN_HELPER_DLL_IMPORT
+#       define MKLDNN_HELPER_DLL_EXPORT
+#   endif
+#endif
+
+#ifdef MKLDNN_DLL
+#   ifdef MKLDNN_DLL_EXPORTS
+#       define MKLDNN_API MKLDNN_HELPER_DLL_EXPORT
+#   else
+#       define MKLDNN_API MKLDNN_HELPER_DLL_IMPORT
+#   endif
+#else
+#   define MKLDNN_API
+#endif
+
+#if defined (__GNUC__)
+#   define MKLDNN_DEPRECATED __attribute__((deprecated))
+#elif defined(_MSC_VER)
+#   define MKLDNN_DEPRECATED __declspec(deprecated)
+#else
+#   define MKLDNN_DEPRECATED
+#endif
+
+#include "mkldnn_types.h"
+#endif /* DOXYGEN_SHOULD_SKIP_THIS */
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+const char MKLDNN_API *mkldnn_status2str(mkldnn_status_t v);
+const char MKLDNN_API *mkldnn_dt2str(mkldnn_data_type_t v);
+const char MKLDNN_API *mkldnn_fmt_kind2str(mkldnn_format_kind_t v);
+const char MKLDNN_API *mkldnn_fmt_tag2str(mkldnn_format_tag_t v);
+const char MKLDNN_API *mkldnn_prop_kind2str(mkldnn_prop_kind_t v);
+const char MKLDNN_API *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v);
+const char MKLDNN_API *mkldnn_alg_kind2str(mkldnn_alg_kind_t v);
+const char MKLDNN_API *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v);
+
+/** Forms a format string for a given memory descriptor.
+ *
+ * The format is defined as: 'dt:[p|o|0]:fmt_kind:fmt:extra'.
+ * Here:
+ *  - dt       -- data type
+ *  - p        -- indicates there is non-trivial padding
+ *  - o        -- indicates there is non-trivial padding offset
+ *  - 0        -- indicates there is non-trivial offset0
+ *  - fmt_kind -- format kind (blocked, wino, etc...)
+ *  - fmt      -- extended format string (format_kind specific)
+ *  - extra    -- shows extra fields (underspecified)
+ */
+int MKLDNN_API mkldnn_md2fmt_str(char *fmt_str, size_t fmt_str_len,
+        const mkldnn_memory_desc_t *md);
+
+/** Forms a dimension string for a given memory descriptor.
+ *
+ * The format is defined as: 'dim0xdim1x...xdimN'.
+ */
+int MKLDNN_API mkldnn_md2dim_str(char *dim_str, size_t dim_str_len,
+        const mkldnn_memory_desc_t *md);
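+
+/* A minimal usage sketch for the two helpers above, assuming an initialized
+ * mkldnn_memory_desc_t `md`, <stdio.h>, and buffers large enough for the
+ * formatted strings:
+ *
+ *     char fmt[128], dims[128];
+ *     mkldnn_md2fmt_str(fmt, sizeof(fmt), &md);   // "dt:[p|o|0]:fmt_kind:fmt:extra"
+ *     mkldnn_md2dim_str(dims, sizeof(dims), &md); // "dim0xdim1x...xdimN"
+ *     printf("%s %s\n", fmt, dims);
+ */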
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif

+ 1415 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn_types.h

@@ -0,0 +1,1415 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_TYPES_H
+#define MKLDNN_TYPES_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#ifndef DOXYGEN_SHOULD_SKIP_THIS
+#include <stddef.h>
+#include <stdint.h>
+#endif
+
+/** @addtogroup c_api C API
+ *  @{
+ *
+ *  @addtogroup c_api_types Types
+ *  @{
+ *
+ *  @addtogroup c_api_types_generic Generic
+ *  @{ */
+
+/** Intel(R) MKL-DNN Version type */
+typedef struct {
+    int    major;
+    int    minor;
+    int    patch;
+    const char *hash;
+} mkldnn_version_t;
+
+/** Status values returned by Intel(R) MKL-DNN functions. */
+typedef enum {
+    /** The operation was successful */
+    mkldnn_success = 0,
+    /** The operation failed due to an out-of-memory condition */
+    mkldnn_out_of_memory = 1,
+    /** The operation failed and should be retried */
+    mkldnn_try_again = 2,
+    /** The operation failed because of incorrect function arguments  */
+    mkldnn_invalid_arguments = 3,
+    /** The operation failed because a primitive was not ready for execution */
+    mkldnn_not_ready = 4,
+    /** The operation failed because the requested functionality is not
+     * implemented */
+    mkldnn_unimplemented = 5,
+    /** The primitive iterator passed over the last primitive descriptor */
+    mkldnn_iterator_ends = 6,
+    /** Primitive or engine failed on execution */
+    mkldnn_runtime_error = 7,
+    /** Queried element is not required for given primitive */
+    mkldnn_not_required = 8,
+} mkldnn_status_t;
+
+/** Data type specification */
+typedef enum {
+    /** Undefined data type, used for empty memory descriptors. */
+    mkldnn_data_type_undef = 0,
+    /** 32-bit/single-precision floating point. */
+    mkldnn_f32 = 1,
+    /** 32-bit signed integer. */
+    mkldnn_s32 = 2,
+    /** 8-bit signed integer. */
+    mkldnn_s8 = 3,
+    /** 8-bit unsigned integer. */
+    mkldnn_u8 = 4,
+} mkldnn_data_type_t;
+
+/** Memory format kind */
+typedef enum {
+    /** Undefined memory format, used for empty memory descriptors. */
+    mkldnn_format_kind_undef = 0,
+    /** Unspecified format. The primitive selects a format automatically. */
+    mkldnn_format_kind_any,
+    /** A tensor in a generic format described by the stride and blocking
+     * values in each dimension. See #mkldnn_blocking_desc_t for more
+     * information. */
+    mkldnn_blocked,
+    /** Weights format used in 8bit Winograd convolution */
+    mkldnn_format_kind_wino,
+    /** Packed weights format used in RNN */
+    mkldnn_format_kind_rnn_packed,
+} mkldnn_format_kind_t;
+
+/** Memory format tag specification.
+ *
+ * Intel MKL-DNN formats describe physical data layout. The physical layout
+ * is described as a sequence of the dimensions as they are laid out in the
+ * memory (from the outer-most to the inner-most). Note that this order
+ * doesn't affect the logical order of the dimensions that is kept in the
+ * `dims` field of the mkldnn_memory_desc_t structure. The logical order of the
+ * dimensions is specified by the type of tensor.
+ *
+ * For example, a 5D CNN tensor always has its logical dimensions in the order
+ * `(batch, channels, depth, height, width)`, while the physical layout might be
+ * #mkldnn_ncdhw or #mkldnn_ndhwc:
+ *
+ * ~~~cpp
+ * int batch = 2, channels = 16, depth = 13, height = 13, width = 13;
+ *
+ * int ndims = 5; // 5D tensor
+ * mkldnn_dims_t dims = {batch, channels, depth, height, width};
+ * mkldnn_memory_desc_t data_in_ncdhw;
+ * mkldnn_memory_desc_init_by_tag(
+ *      &data_in_ncdhw, 5, dims, mkldnn_f32, mkldnn_ncdhw);
+ *
+ * // note that in both cases dims passed are the same
+ * mkldnn_memory_desc_t data_in_ndhwc;
+ * mkldnn_memory_desc_init_by_tag(
+ *      &data_in_ndhwc, 5, dims, mkldnn_f32, mkldnn_ndhwc);
+ * ~~~
+ *
+ * The following notation applies to memory format names:
+ *  - @c 'n' denotes the mini-batch dimension
+ *  - @c 'c' denotes a channels dimension
+ *  - When there are multiple channel dimensions (for example, in convolution
+ *    weights tensor), @c 'i' and @c 'o' denote dimensions of input and output
+ *    channels
+ *  - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
+ *    respectively
+ *  - Upper-case letters indicate that the data is laid out in blocks
+ *    for a particular dimension. In such cases, the format name contains both
+ *    upper- and lower-case letters for that dimension with a lower-case letter
+ *    preceded by the block size. For example: @c 'mkldnn_nChw8c' describes a
+ *    format where the outermost dimension is mini-batch, followed by the
+ *    channel block number, followed by the spatial height and width, and
+ *    finally followed by 8-element channel blocks.
+ *
+ * @note
+ *    Channel designations can be different. For example, both the @c
+ *    'mkldnn_nc' and @c 'mkldnn_io' formats can be used to describe a 2D
+ *    tensor.
+ *
+ * @sa @ref understanding_memory_formats
+ */
+typedef enum {
+    /** Undefined memory format tag */
+    mkldnn_format_tag_undef = 0,
+    /** Undefined memory format tag.
+     * The primitive selects a format automatically. */
+    mkldnn_format_tag_any,
+
+    /* Semantic agnostic section */
+    /* The physical order of dimensions is defined by the permutation of the
+     * characters, assuming that ab..z defines the natural order.
+     */
+
+    /* Plain formats */
+
+    mkldnn_a,
+    mkldnn_ab,
+    mkldnn_abc,
+    mkldnn_abcd,
+    mkldnn_abcde,
+    mkldnn_abcdef,
+    mkldnn_abdec,
+    mkldnn_acb,
+    mkldnn_acbde,
+    mkldnn_acdb,
+    mkldnn_acdeb,
+    mkldnn_ba,
+    mkldnn_bac,
+    mkldnn_bacd,
+    mkldnn_bcda,
+    mkldnn_cba,
+    mkldnn_cdba,
+    mkldnn_cdeba,
+    mkldnn_decab,
+
+    /* Opaque blocked formats */
+
+    mkldnn_Abc16a,
+    mkldnn_ABc16a16b,
+    mkldnn_aBc16b,
+    mkldnn_ABc16b16a,
+    mkldnn_Abc4a,
+    mkldnn_aBc4b,
+    mkldnn_ABc4b16a4b,
+    mkldnn_ABc4b4a,
+    mkldnn_ABc8a16b2a,
+    mkldnn_ABc8a8b,
+    mkldnn_aBc8b,
+    mkldnn_ABc8b16a2b,
+    mkldnn_ABc8b8a,
+    mkldnn_Abcd16a,
+    mkldnn_ABcd16a16b,
+    mkldnn_aBcd16b,
+    mkldnn_ABcd16b16a,
+    mkldnn_aBCd16b16c,
+    mkldnn_aBCd16c16b,
+    mkldnn_Abcd4a,
+    mkldnn_aBcd4b,
+    mkldnn_ABcd4b16a4b,
+    mkldnn_ABcd4b4a,
+    mkldnn_aBCd4c16b4c,
+    mkldnn_aBCd4c4b,
+    mkldnn_ABcd8a16b2a,
+    mkldnn_ABcd8a8b,
+    mkldnn_aBcd8b,
+    mkldnn_ABcd8b16a2b,
+    mkldnn_aBCd8b16c2b,
+    mkldnn_ABcd8b8a,
+    mkldnn_aBCd8b8c,
+    mkldnn_aBCd8c16b2c,
+    mkldnn_aBCd8c8b,
+    mkldnn_Abcde16a,
+    mkldnn_ABcde16a16b,
+    mkldnn_aBcde16b,
+    mkldnn_ABcde16b16a,
+    mkldnn_aBCde16b16c,
+    mkldnn_aBCde16c16b,
+    mkldnn_aBCde2c8b4c,
+    mkldnn_Abcde4a,
+    mkldnn_aBcde4b,
+    mkldnn_ABcde4b4a,
+    mkldnn_aBCde4b4c,
+    mkldnn_aBCde4c16b4c,
+    mkldnn_aBCde4c4b,
+    mkldnn_Abcde8a,
+    mkldnn_ABcde8a8b,
+    mkldnn_aBcde8b,
+    mkldnn_ABcde8b16a2b,
+    mkldnn_aBCde8b16c2b,
+    mkldnn_ABcde8b8a,
+    mkldnn_aBCde8b8c,
+    mkldnn_aBCde8c16b2c,
+    mkldnn_aBCde8c8b,
+    mkldnn_aBcdef16b,
+    mkldnn_aBCdef16b16c,
+    mkldnn_aBCdef16c16b,
+    mkldnn_aBcdef4b,
+    mkldnn_aBCdef4c4b,
+    mkldnn_aBCdef8b8c,
+    mkldnn_aBCdef8c16b2c,
+    mkldnn_aBCdef8c8b,
+    mkldnn_aBdc16b,
+    mkldnn_aBdc4b,
+    mkldnn_aBdc8b,
+    mkldnn_aBdec16b,
+    mkldnn_aBdec4b,
+    mkldnn_aBdec8b,
+    mkldnn_aBdefc16b,
+    mkldnn_aBdefc4b,
+    mkldnn_aBdefc8b,
+    mkldnn_Acb16a,
+    mkldnn_Acb4a,
+    mkldnn_Acb8a,
+    mkldnn_aCBd16b16c,
+    mkldnn_aCBde16b16c,
+    mkldnn_Acdb16a,
+    mkldnn_Acdb4a,
+    mkldnn_Acdb8a,
+    mkldnn_Acdeb16a,
+    mkldnn_Acdeb4a,
+    mkldnn_Acdeb8a,
+    mkldnn_BAc16a16b,
+    mkldnn_BAcd16a16b,
+
+    /** Just a sentinel, not a real memory format tag. Must be changed after a
+     * new format tag is added. */
+    mkldnn_format_tag_last,
+
+    /* Aliases */
+
+    mkldnn_x = mkldnn_a,
+    mkldnn_nc = mkldnn_ab,
+    mkldnn_cn = mkldnn_ba,
+    mkldnn_ncw = mkldnn_abc,
+    mkldnn_nwc = mkldnn_acb,
+    mkldnn_nchw = mkldnn_abcd,
+    mkldnn_nhwc = mkldnn_acdb,
+    mkldnn_chwn = mkldnn_bcda,
+    mkldnn_ncdhw = mkldnn_abcde,
+    mkldnn_ndhwc = mkldnn_acdeb,
+
+    mkldnn_oi = mkldnn_ab,
+    mkldnn_io = mkldnn_ba,
+    mkldnn_oiw = mkldnn_abc,
+    mkldnn_wio = mkldnn_cba,
+    mkldnn_oihw = mkldnn_abcd,
+    mkldnn_hwio = mkldnn_cdba,
+    mkldnn_ihwo = mkldnn_bcda,
+    mkldnn_iohw = mkldnn_bacd,
+    mkldnn_oidhw = mkldnn_abcde,
+    mkldnn_dhwio = mkldnn_cdeba,
+    mkldnn_goiw = mkldnn_abcd,
+    mkldnn_goihw = mkldnn_abcde,
+    mkldnn_hwigo = mkldnn_decab,
+    mkldnn_giohw = mkldnn_acbde,
+    mkldnn_goidhw = mkldnn_abcdef,
+
+    /** 3D RNN data tensor in the format (seq_length, batch, input channels). */
+    mkldnn_tnc = mkldnn_abc,
+    /** 3D RNN data tensor in the format (batch, seq_length, input channels). */
+    mkldnn_ntc = mkldnn_bac,
+    /** 5D RNN states tensor in the format (num_layers, num_directions,
+     * num_states, batch, state channels). */
+    mkldnn_ldsnc = mkldnn_abcde,
+    /** 5D RNN weights tensor in the format (num_layers, num_directions,
+     *  input_channels, num_gates, output_channels).
+     *
+     *  - For LSTM cells, the gates order is input, forget, candidate
+     *    and output gate.
+     *  - For GRU cells, the gates order is update, reset and output gate. */
+    mkldnn_ldigo = mkldnn_abcde,
+    /** 5D RNN weights tensor in the format (num_layers, num_directions,
+     * num_gates, output_channels, input_channels).
+     *
+     *  - For LSTM cells, the gates order is input, forget, candidate
+     *    and output gate.
+     *  - For GRU cells, the gates order is update, reset and output gate. */
+    mkldnn_ldgoi = mkldnn_abdec,
+    /** 4D RNN bias tensor in the format (num_layers, num_directions,
+     * num_gates, output_channels).
+     *
+     *  - For LSTM cells, the gates order is input, forget, candidate
+     *    and output gate.
+     *  - For GRU cells, the gates order is update, reset and output gate. */
+    mkldnn_ldgo = mkldnn_abcd,
+
+    /* Opaque data types, are not to be used explicitly */
+
+    /* data */
+    mkldnn_nCdhw16c = mkldnn_aBcde16b,
+    mkldnn_nCdhw4c = mkldnn_aBcde4b,
+    mkldnn_nCdhw8c = mkldnn_aBcde8b,
+    mkldnn_nChw16c = mkldnn_aBcd16b,
+    mkldnn_nChw4c = mkldnn_aBcd4b,
+    mkldnn_nChw8c = mkldnn_aBcd8b,
+    mkldnn_nCw16c = mkldnn_aBc16b,
+    mkldnn_nCw4c = mkldnn_aBc4b,
+    mkldnn_nCw8c = mkldnn_aBc8b,
+
+    /* weights, 3D */
+    mkldnn_IOw16o16i = mkldnn_BAc16a16b,
+    mkldnn_OIw16i16o = mkldnn_ABc16b16a,
+    mkldnn_OIw16o16i = mkldnn_ABc16a16b,
+    mkldnn_Oiw16o = mkldnn_Abc16a,
+    mkldnn_OIw4i16o4i = mkldnn_ABc4b16a4b,
+    mkldnn_OIw4i4o = mkldnn_ABc4b4a,
+    mkldnn_Oiw4o = mkldnn_Abc4a,
+    mkldnn_OIw8i16o2i = mkldnn_ABc8b16a2b,
+    mkldnn_OIw8i8o = mkldnn_ABc8b8a,
+    mkldnn_OIw8o16i2o = mkldnn_ABc8a16b2a,
+    mkldnn_OIw8o8i = mkldnn_ABc8a8b,
+    mkldnn_Owi16o = mkldnn_Acb16a,
+    mkldnn_Owi4o = mkldnn_Acb4a,
+    mkldnn_Owi8o = mkldnn_Acb8a,
+
+    /* weights, 4D */
+    mkldnn_IOhw16o16i = mkldnn_BAcd16a16b,
+    mkldnn_Ohwi16o = mkldnn_Acdb16a,
+    mkldnn_Ohwi4o = mkldnn_Acdb4a,
+    mkldnn_Ohwi8o = mkldnn_Acdb8a,
+    mkldnn_OIhw16i16o = mkldnn_ABcd16b16a,
+    mkldnn_OIhw16o16i = mkldnn_ABcd16a16b,
+    mkldnn_Oihw16o = mkldnn_Abcd16a,
+    mkldnn_OIhw4i16o4i = mkldnn_ABcd4b16a4b,
+    mkldnn_OIhw4i4o = mkldnn_ABcd4b4a,
+    mkldnn_Oihw4o = mkldnn_Abcd4a,
+    mkldnn_OIhw8i16o2i = mkldnn_ABcd8b16a2b,
+    mkldnn_OIhw8i8o = mkldnn_ABcd8b8a,
+    mkldnn_OIhw8o16i2o = mkldnn_ABcd8a16b2a,
+    mkldnn_OIhw8o8i = mkldnn_ABcd8a8b,
+
+    /* weights, 5D */
+    mkldnn_Odhwi16o = mkldnn_Acdeb16a,
+    mkldnn_Odhwi4o = mkldnn_Acdeb4a,
+    mkldnn_Odhwi8o = mkldnn_Acdeb8a,
+    mkldnn_OIdhw16i16o = mkldnn_ABcde16b16a,
+    mkldnn_OIdhw16o16i = mkldnn_ABcde16a16b,
+    mkldnn_Oidhw16o = mkldnn_Abcde16a,
+    mkldnn_OIdhw4i4o = mkldnn_ABcde4b4a,
+    mkldnn_Oidhw4o = mkldnn_Abcde4a,
+    mkldnn_OIdhw8i16o2i = mkldnn_ABcde8b16a2b,
+    mkldnn_OIdhw8i8o = mkldnn_ABcde8b8a,
+    mkldnn_OIdhw8o8i = mkldnn_ABcde8a8b,
+
+    /* weights w/ groups, 3D */
+    mkldnn_Goiw16g = mkldnn_Abcd16a,
+    mkldnn_gIOw16o16i = mkldnn_aCBd16b16c,
+    mkldnn_gOIw16i16o = mkldnn_aBCd16c16b,
+    mkldnn_gOIw16o16i = mkldnn_aBCd16b16c,
+    mkldnn_gOiw16o = mkldnn_aBcd16b,
+    mkldnn_gOIw4i16o4i = mkldnn_aBCd4c16b4c,
+    mkldnn_gOIw4i4o = mkldnn_aBCd4c4b,
+    mkldnn_gOiw4o = mkldnn_aBcd4b,
+    mkldnn_gOIw8i16o2i = mkldnn_aBCd8c16b2c,
+    mkldnn_gOIw8i8o = mkldnn_aBCd8c8b,
+    mkldnn_gOIw8o16i2o = mkldnn_aBCd8b16c2b,
+    mkldnn_gOIw8o8i = mkldnn_aBCd8b8c,
+    mkldnn_gOwi16o = mkldnn_aBdc16b,
+    mkldnn_gOwi4o = mkldnn_aBdc4b,
+    mkldnn_gOwi8o = mkldnn_aBdc8b,
+
+    /* weights w/ groups, 4D */
+    mkldnn_gIOhw16o16i = mkldnn_aCBde16b16c,
+    mkldnn_gOhwi16o = mkldnn_aBdec16b,
+    mkldnn_gOhwi4o = mkldnn_aBdec4b,
+    mkldnn_gOhwi8o = mkldnn_aBdec8b,
+    mkldnn_Goihw16g = mkldnn_Abcde16a,
+    mkldnn_gOIhw16i16o = mkldnn_aBCde16c16b,
+    mkldnn_gOIhw16o16i = mkldnn_aBCde16b16c,
+    mkldnn_gOihw16o = mkldnn_aBcde16b,
+    mkldnn_gOIhw2i8o4i = mkldnn_aBCde2c8b4c,
+    mkldnn_gOIhw4i16o4i = mkldnn_aBCde4c16b4c,
+    mkldnn_gOIhw4i4o = mkldnn_aBCde4c4b,
+    mkldnn_gOIhw4o4i = mkldnn_aBCde4b4c,
+    mkldnn_gOihw4o = mkldnn_aBcde4b,
+    mkldnn_Goihw8g = mkldnn_Abcde8a,
+    mkldnn_gOIhw8i16o2i = mkldnn_aBCde8c16b2c,
+    mkldnn_gOIhw8i8o = mkldnn_aBCde8c8b,
+    mkldnn_gOIhw8o16i2o = mkldnn_aBCde8b16c2b,
+    mkldnn_gOIhw8o8i = mkldnn_aBCde8b8c,
+
+    /* weights w/ groups, 6D */
+    mkldnn_gOdhwi16o = mkldnn_aBdefc16b,
+    mkldnn_gOdhwi4o = mkldnn_aBdefc4b,
+    mkldnn_gOdhwi8o = mkldnn_aBdefc8b,
+    mkldnn_gOIdhw16i16o = mkldnn_aBCdef16c16b,
+    mkldnn_gOIdhw16o16i = mkldnn_aBCdef16b16c,
+    mkldnn_gOidhw16o = mkldnn_aBcdef16b,
+    mkldnn_gOIdhw4i4o = mkldnn_aBCdef4c4b,
+    mkldnn_gOidhw4o = mkldnn_aBcdef4b,
+    mkldnn_gOIdhw8i16o2i = mkldnn_aBCdef8c16b2c,
+    mkldnn_gOIdhw8i8o = mkldnn_aBCdef8c8b,
+    mkldnn_gOIdhw8o8i = mkldnn_aBCdef8b8c,
+} mkldnn_format_tag_t;
+
+/** Kinds of padding. Defines how to interpret the data in padding regions. */
+typedef enum {
+    /** The data in padding regions is zero. */
+    mkldnn_padding_zero,
+} mkldnn_padding_kind_t;
+
+/** Kinds of propagation. */
+typedef enum {
+    /* TODO: suggest renames */
+    /** Undefined propagation type. */
+    mkldnn_prop_kind_undef = 0,
+    /** Forward data propagation (training mode). In this mode primitives
+     * perform computations necessary for subsequent backward propagation. */
+    mkldnn_forward_training = 64,
+    /** Forward data propagation (inference mode). In this mode primitives
+     * perform only computations that are necessary for inference and omit
+     * computations that are necessary only for backward propagation. */
+    mkldnn_forward_inference = 96,
+    /** Forward data propagation (alias for @c mkldnn_forward_inference) */
+    mkldnn_forward_scoring = mkldnn_forward_inference,
+    /** Forward data propagation (alias for @c mkldnn_forward_training) */
+    mkldnn_forward = mkldnn_forward_training,
+    /** Backward propagation (with respect to all parameters) */
+    mkldnn_backward = 128,
+    /** Backward data propagation */
+    mkldnn_backward_data = 160,
+    /** Backward weights propagation */
+    mkldnn_backward_weights = 192,
+    /** Backward bias propagation */
+    mkldnn_backward_bias = 193,
+} mkldnn_prop_kind_t;
+
+/** Kinds of primitives. Used to implement a way to extend the library with new
+ * primitives without changing the ABI. */
+typedef enum {
+    /** Undefined primitive (XXX: why do we have it?). */
+    mkldnn_undefined_primitive,
+    /** A reorder primitive.*/
+    mkldnn_reorder,
+    /** A shuffle primitive.*/
+    mkldnn_shuffle,
+    /** A (out-of-place) concat primitive. */
+    mkldnn_concat,
+    /** A sum primitive. */
+    mkldnn_sum,
+    /** A convolution primitive. */
+    mkldnn_convolution,
+    /** A deconvolution primitive. */
+    mkldnn_deconvolution,
+    /** An element-wise primitive. */
+    mkldnn_eltwise,
+    /** A Softmax primitive. */
+    mkldnn_softmax,
+    /** A pooling primitive. */
+    mkldnn_pooling,
+    /** An LRN primitive. */
+    mkldnn_lrn,
+    /** A batch normalization primitive. */
+    mkldnn_batch_normalization,
+    /** An inner product primitive. */
+    mkldnn_inner_product,
+    /** An RNN primitive. */
+    mkldnn_rnn,
+} mkldnn_primitive_kind_t;
+
+/** Kinds of algorithms. */
+typedef enum {
+    mkldnn_alg_kind_undef,
+    /** Direct convolution */
+    mkldnn_convolution_direct = 0x1,
+    /** Winograd convolution */
+    mkldnn_convolution_winograd = 0x2,
+    /** Convolution algorithm (either direct or Winograd) is chosen just in time */
+    mkldnn_convolution_auto = 0x3,
+    /** Direct deconvolution */
+    mkldnn_deconvolution_direct = 0xa,
+    /** Winograd deconvolution */
+    mkldnn_deconvolution_winograd = 0xb,
+    /** Eltwise: ReLU */
+    mkldnn_eltwise_relu = 0x1f,
+    /** Eltwise: hyperbolic tangent non-linearity (tanh) */
+    mkldnn_eltwise_tanh = 0x2f,
+    /** Eltwise: parametric exponential linear unit (elu) */
+    mkldnn_eltwise_elu = 0x3f,
+    /** Eltwise: square */
+    mkldnn_eltwise_square = 0x4f,
+    /** Eltwise: abs */
+    mkldnn_eltwise_abs = 0x5f,
+    /** Eltwise: square root */
+    mkldnn_eltwise_sqrt = 0x6f,
+    /** Eltwise: linear */
+    mkldnn_eltwise_linear = 0x7f,
+    /** Eltwise: bounded_relu */
+    mkldnn_eltwise_bounded_relu = 0x8f,
+    /** Eltwise: soft_relu */
+    mkldnn_eltwise_soft_relu = 0x9f,
+    /** Eltwise: logistic */
+    mkldnn_eltwise_logistic = 0xaf,
+    /** Max pooling */
+    mkldnn_pooling_max = 0x1ff,
+    /** Average pooling include padding */
+    mkldnn_pooling_avg_include_padding = 0x2ff,
+    /** Average pooling exclude padding */
+    mkldnn_pooling_avg_exclude_padding = 0x3ff,
+    mkldnn_pooling_avg = mkldnn_pooling_avg_exclude_padding,
+    /** Local response normalization (LRN) across multiple channels */
+    mkldnn_lrn_across_channels = 0xaff,
+    /** LRN within a single channel */
+    mkldnn_lrn_within_channel = 0xbff,
+    /** RNN cell */
+    mkldnn_vanilla_rnn = 0x1fff,
+    /** LSTM cell */
+    mkldnn_vanilla_lstm = 0x2fff,
+    /** GRU cell */
+    mkldnn_vanilla_gru = 0x3fff,
+    /** GRU cell with linear before reset
+     *
+     * Modification of original GRU cell. Differs from #mkldnn_vanilla_gru
+     * in how the new memory gate is calculated:
+     * \f[ c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f]
+     * Primitive expects 4 biases on input:
+     * \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
+     * */
+    mkldnn_gru_linear_before_reset = 0x4fff,
+} mkldnn_alg_kind_t;
+
+/** Flags for the batch-normalization primitive. */
+typedef enum {
+    /** Use global statistics
+     *
+     * If specified
+     *  - on forward propagation use mean and variance provided by user (input)
+     *  - on backward propagation reduces the amount of computations, since
+     *    mean and variance are considered as constants
+     *
+     *  If not specified:
+     *   - on forward propagation mean and variance are computed and stored in
+     *     output
+     *   - on backward propagation compute the full derivative with respect to data
+     */
+    mkldnn_use_global_stats = 0x1U,
+    /** Use scale and shift parameters
+     *
+     * If specified:
+     *  - on forward propagation use scale and shift (aka scale and bias) for
+     *    the batch normalization results
+     *  - on backward propagation (for prop_kind == #mkldnn_backward) compute
+     *    diff wrt to scale and shift (hence one extra output used)
+     *
+     * If not specified:
+     *  - on backward propagation prop_kind == #mkldnn_backward_data has the
+     *    same behavior as prop_kind == #mkldnn_backward
+     */
+    mkldnn_use_scaleshift = 0x2U,
+    /** Fuse with ReLU
+     *
+     * If specified:
+     *  - on inference this option behaves the same as if the primitive were
+     *    fused with ReLU via post ops API
+     *  - on training primitive requires workspace (required to be able to
+     *    perform backward pass)
+     */
+    mkldnn_fuse_bn_relu = 0x4U,
+} mkldnn_batch_normalization_flag_t;
+
+/** @} */
+
+/** @addtogroup c_api_types_memory Memory
+ *  @{ */
+
+/** Maximum number of dimensions a tensor can have. Only restricts the amount
+ * of space used for the tensor description. Individual computational
+ * primitives may support only tensors of certain dimensions. */
+#define MKLDNN_MAX_NDIMS 12
+
+/** A type to describe tensor dimension. */
+typedef int64_t mkldnn_dim_t;
+
+/** A type to describe tensor dimensions. */
+typedef mkldnn_dim_t mkldnn_dims_t[MKLDNN_MAX_NDIMS];
+
+/** A type to describe strides within a tensor. */
+typedef mkldnn_dim_t mkldnn_strides_t[MKLDNN_MAX_NDIMS];
+
+/** Generic description of blocked data layout for most memory formats.
+ *
+ * @sa @ref understanding_memory_formats */
+typedef struct {
+    /** The strides between the outermost blocks.
+     * In the case of plain (non-blocked) formats, these are the strides
+     * between dimensions. */
+    mkldnn_dims_t strides;
+    /* Innermost section
+     * ASSUMPTION: the innermost blocks are always dense */
+    /** The number of innermost blocks, e.g. 3 in case of `OIhw_4i16o4i_` */
+    int inner_nblks;
+    /** The size of the blocks, e.g. `{4, 16, 4}` in case of `OIhw_4i16o4i` */
+    mkldnn_dims_t inner_blks;
+    /** The logical indices of the blocks, e.g. `{1, 0, 1}` in case of
+     * `4i16o4i`, because `i` is the 1st dim and `o` is the 0th dim */
+    mkldnn_dims_t inner_idxs;
+} mkldnn_blocking_desc_t;
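+
+/* A sketch of reading the blocking layout back out of an initialized memory
+ * descriptor `md` whose format_kind is mkldnn_blocked (format_desc is defined
+ * further below). For the `OIhw_4i16o4i_` example from the comments above one
+ * would observe:
+ *
+ *     const mkldnn_blocking_desc_t *blk = &md.format_desc.blocking;
+ *     // blk->inner_nblks == 3
+ *     // blk->inner_blks  == {4, 16, 4}
+ *     // blk->inner_idxs  == {1, 0, 1}   (i, o, i)
+ */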
+
+typedef enum {
+    /** Undefined memory format, used for empty memory descriptors. */
+    mkldnn_wino_undef = 0,
+    /** Tensors of weights for 2x3 winograd convolutions. */
+    mkldnn_wino_wei_aaOIoi,
+    mkldnn_wino_wei_aaOio,
+    mkldnn_wino_wei_aaOBiOo,
+    /** Tensor of weights for 4x3 convolution. */
+    mkldnn_wino_wei_OBaaIBOIio
+} mkldnn_wino_memory_format_t;
+
+/** Description of tensor of weights for winograd 2x3 convolution. */
+typedef struct {
+    mkldnn_wino_memory_format_t wino_format;
+    int r;
+    int alpha;
+    int ic;
+    int oc;
+    int ic_block;
+    int oc_block;
+    int ic2_block;
+    int oc2_block;
+    float adj_scale;
+    size_t size;
+} mkldnn_wino_desc_t;
+
+typedef enum {
+    mkldnn_packed_format_undef = 0,
+    mkldnn_ldigo_p,
+    mkldnn_ldgoi_p
+} mkldnn_rnn_packed_memory_format_t;
+
+/* Maximum number of parts of RNN weights tensor that require separate
+ * computation. */
+#define MKLDNN_RNN_MAX_N_PARTS 4
+
+/** Description of the tensor of packed weights for RNN. */
+typedef struct {
+    mkldnn_rnn_packed_memory_format_t format;
+    int n_parts;
+    int n;
+    int parts[MKLDNN_RNN_MAX_N_PARTS];
+    size_t part_pack_size[MKLDNN_RNN_MAX_N_PARTS];
+    size_t offset_compensation;
+    size_t size;
+} mkldnn_rnn_packed_desc_t;
+
+typedef enum {
+    mkldnn_memory_extra_flag_none = 0x0U,
+    /** Indicates the weights have an additional buffer that depends on the
+     * @p compensation_mask.
+     *
+     * For instance, in 4D case with the compensation mask equals (1 << 0)
+     * the additional buffer would consist of OC values:
+     * O[oc : 0,OC] =
+     *  -128 * SUM(ic : 0,IC; kh : 0,KH; kw : 0,KW){ weights(oc, ic, kh, kw) }
+     */
+    mkldnn_memory_extra_flag_compensation_conv_s8s8 = 0x1U,
+    mkldnn_memory_extra_flag_scale_adjust = 0x2U,
+} mkldnn_memory_extra_flags_t;
+
+/** Description of extra information stored in memory */
+typedef struct {
+    /** The flags contain arbitrary extra information, such as compensation.
+     * @sa mkldnn_memory_extra_flags_t */
+    uint64_t flags;
+    /** Compensation mask */
+    int compensation_mask;
+    /** Scale applied to the data */
+    float scale_adjust;
+    /** For future backwards compatibility */
+    char reserved[64];
+} mkldnn_memory_extra_desc_t;
+
+/** Memory descriptor. The description is based on the number of dimensions,
+ * the dimensions themselves, the element data type, and the memory format.
+ * Additionally, it contains format-specific descriptions of the data
+ * layout. */
+typedef struct {
+    /** Number of dimensions */
+    int ndims;
+    /** Dimensions in the following order:
+     * - CNN data tensors: mini-batch, channel, spatial
+     *   (<code>{N, C, [[D,] H,] W}</code>)
+     * - CNN weight tensors: group (optional), output channel, input channel,
+     *   spatial (<code>{[G,] O, I, [[D,] H,] W}</code>)
+     * - RNN data tensors: time, mini-batch, channels (<code>{T, N, C}</code>)
+     *   or layers, directions, states, mini-batch, channels (<code>{L, D, S, N, C}</code>)
+     * - RNN weight tensor: layers, directions, input channel, gates, output channels
+     *   (<code>{L, D, I, G, O}</code>).
+     *
+     * @note
+     *    The order of dimensions does not depend on the memory format, so
+     *    whether the data is laid out in #mkldnn_nchw or #mkldnn_nhwc
+     *    the dims for a 4D CNN data tensor would be <code>{N, C, H, W}</code>.
+     */
+    mkldnn_dims_t dims;
+    /** Data type of the tensor elements. */
+    mkldnn_data_type_t data_type;
+
+    /** Size of the data including padding in each dimension. */
+    mkldnn_dims_t padded_dims;
+    /** Per-dimension offset from the padding to the actual data; the top-level
+     * tensor with offsets applied must lie within the padding area. */
+    mkldnn_dims_t padded_offsets;
+
+    /** Offset from memory origin to the current block, non-zero only in
+     * a description of a memory sub-block. */
+    mkldnn_dim_t offset0;
+
+    /** Memory format kind. */
+    mkldnn_format_kind_t format_kind;
+    union {
+        /** Description of the data layout for memory formats that use
+         * blocking. */
+        mkldnn_blocking_desc_t blocking;
+        /** Tensor of weights for integer 8bit winograd convolution. */
+        mkldnn_wino_desc_t wino_desc;
+        /** Tensor of packed weights for RNN. */
+        mkldnn_rnn_packed_desc_t rnn_packed_desc;
+        /* ... other descriptions possible */
+    } format_desc;
+
+    mkldnn_memory_extra_desc_t extra;
+} mkldnn_memory_desc_t;
+
+/** @struct mkldnn_memory
+ * An opaque structure to describe a memory. */
+struct mkldnn_memory;
+
+/** A memory handle. */
+typedef struct mkldnn_memory *mkldnn_memory_t;
+
+/** A constant memory handle. */
+typedef const struct mkldnn_memory *const_mkldnn_memory_t;
+
+#define MKLDNN_NATIVE_HANDLE_NONE (NULL)
+#define MKLDNN_NATIVE_HANDLE_ALLOCATE ((void *)(size_t)-1)
+
+/** @} */
+
+/** @addtogroup c_api_types_op_descs Operation descriptors
+ *  @{*/
+
+/** A pointer to any of the operation descriptors. */
+typedef void *mkldnn_op_desc_t;
+/** A pointer to any of the operation descriptors (constant variant). */
+typedef const void *const_mkldnn_op_desc_t;
+
+/** A descriptor of a convolution operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_convolution. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward_data,
+     * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
+    mkldnn_prop_kind_t prop_kind;
+    /** The kind of the convolution algorithm. Possible values:
+     * #mkldnn_convolution_direct. */
+    mkldnn_alg_kind_t alg_kind;
+    /** Source memory descriptor. */
+    mkldnn_memory_desc_t src_desc;
+    /** Source gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_src_desc;
+    /** Weights memory descriptor. */
+    mkldnn_memory_desc_t weights_desc;
+    /** Weights gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_weights_desc;
+    /** Bias memory descriptor. */
+    mkldnn_memory_desc_t bias_desc;
+    /** Bias gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_bias_desc;
+    /** Destination memory descriptor. */
+    mkldnn_memory_desc_t dst_desc;
+    /** Destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_desc;
+    /** Convolution strides in each spatial dimension. */
+    mkldnn_dims_t strides;
+    /** Convolution dilates in each spatial dimension. */
+    mkldnn_dims_t dilates;
+    /** Padding in each spatial dimension. padding[0] is the padding at the
+     * beginning (@p padding_l), padding[1] is the padding at the end (@p
+     * padding_r). */
+    mkldnn_dims_t padding[2];
+    /** The kind of padding to use. */
+    mkldnn_padding_kind_t padding_kind;
+    /** The accumulator data type. Initialized automatically. */
+    mkldnn_data_type_t accum_data_type;
+} mkldnn_convolution_desc_t;
+
+/** A descriptor of a deconvolution operation. */
+typedef mkldnn_convolution_desc_t mkldnn_deconvolution_desc_t;
+
+/** A descriptor of a shuffle operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_shuffle. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, and #mkldnn_backward_data. */
+    mkldnn_prop_kind_t prop_kind;
+    /** Source and destination memory descriptor,
+     *  and source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** The axis along which to perform the shuffle. */
+    int axis;
+    /** The number of groups (as in, e.g., group convolution). */
+    mkldnn_dim_t group_size;
+} mkldnn_shuffle_desc_t;
+
+/** A descriptor of an element-wise operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_eltwise. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
+     */
+    mkldnn_prop_kind_t prop_kind;
+    /** The kind of eltwise algorithm. Possible values: #mkldnn_eltwise_relu,
+     * #mkldnn_eltwise_tanh, #mkldnn_eltwise_elu, #mkldnn_eltwise_square,
+     * #mkldnn_eltwise_abs, #mkldnn_eltwise_sqrt, #mkldnn_eltwise_linear,
+     * #mkldnn_eltwise_bounded_relu, #mkldnn_eltwise_soft_relu, and
+     * #mkldnn_eltwise_logistic. */
+    mkldnn_alg_kind_t alg_kind;
+    /** Source and destination memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_data_desc;
+    /** Algorithm specific parameter.
+     * Accordance table:
+     *  - #mkldnn_eltwise_relu: @p alpha -- negative slope, @p beta ignored
+     *  - #mkldnn_eltwise_tanh: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_elu: @p alpha -- coefficient of the negative branch, @p beta ignored
+     *  - #mkldnn_eltwise_square: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_abs: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_sqrt: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_linear: @p alpha -- scale, @p beta -- shift
+     *  - #mkldnn_eltwise_bounded_relu: @p alpha -- upper bound, @p beta ignored
+     *  - #mkldnn_eltwise_soft_relu: @p alpha and @p beta ignored
+     *  - #mkldnn_eltwise_logistic: @p alpha and @p beta ignored
+     */
+    float alpha, beta;
+} mkldnn_eltwise_desc_t;
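To make the accordance table concrete, here is a hedged editorial sketch of how @p alg_kind, @p alpha, and @p beta might be filled for a bounded ReLU clamped at 6; in practice the descriptor is normally populated through the init helpers declared in mkldnn.h, and @p data_desc would also need a valid memory descriptor.

mkldnn_eltwise_desc_t ed = {0};          /* illustration only */
ed.primitive_kind = mkldnn_eltwise;
ed.prop_kind = mkldnn_forward_inference;
ed.alg_kind = mkldnn_eltwise_bounded_relu;
ed.alpha = 6.0f;                         /* upper bound (see table above) */
ed.beta = 0.0f;                          /* ignored for bounded_relu */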
+
+/** A descriptor of a Softmax operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+    * descriptor. Must be #mkldnn_softmax. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training and
+     * #mkldnn_forward_inference. */
+    mkldnn_prop_kind_t prop_kind;
+    /** Source and destination memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_desc;
+    /** The axis along which to perform the softmax. */
+    int softmax_axis;
+} mkldnn_softmax_desc_t;
+
+/** A descriptor of a pooling operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_pooling. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
+     */
+    mkldnn_prop_kind_t prop_kind;
+    /** The kind of pooling algorithm. Possible values: #mkldnn_pooling_max and
+     * #mkldnn_pooling_avg. */
+    mkldnn_alg_kind_t alg_kind;
+    /** Source memory descriptor. */
+    mkldnn_memory_desc_t src_desc;
+    /** Source gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_src_desc;
+    /** Destination memory descriptor. */
+    mkldnn_memory_desc_t dst_desc;
+    /** Destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_desc;
+    /** Pooling kernel strides for spatial dimensions. */
+    mkldnn_dims_t strides;
+    /** Pooling kernel spatial dimensions. */
+    mkldnn_dims_t kernel;
+    /** Padding in each spatial dimension. padding[0] is the padding at the
+     * beginning (@p padding_l), padding[1] is the padding at the end (@p
+     * padding_r). */
+    mkldnn_dims_t padding[2];
+    /** The kind of padding to use. */
+    mkldnn_padding_kind_t padding_kind;
+    /** The accumulator data type. Initialized automatically. */
+    mkldnn_data_type_t accum_data_type;
+} mkldnn_pooling_desc_t;
+
+/** A descriptor of a Local Response Normalization (LRN) operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_lrn. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
+     */
+    mkldnn_prop_kind_t prop_kind;
+    /** LRN algorithm. Possible values: #mkldnn_lrn_within_channel and
+     * #mkldnn_lrn_across_channels. */
+    mkldnn_alg_kind_t alg_kind;
+    /** Source and destination memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_data_desc;
+    /** The number of channels to sum over (for cross-channel LRN) or the side
+     * length of the square region to sum over (for within-channel LRN). */
+    mkldnn_dim_t local_size;
+    /** LRN alpha parameter. */
+    float lrn_alpha;
+    /** LRN beta parameter. */
+    float lrn_beta;
+    /** LRN k parameter. */
+    float lrn_k;
+} mkldnn_lrn_desc_t;
+
+/** A descriptor of a Batch Normalization operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_batch_normalization. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward, and #mkldnn_backward_data.
+     */
+    mkldnn_prop_kind_t prop_kind;
+    /** Source and destination memory descriptor. */
+    mkldnn_memory_desc_t data_desc;
+    /** Source and destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_data_desc;
+    /** Scale and shift data and gradient memory descriptors.
+     *
+     * The scale-shift memory descriptor uses the 2D #mkldnn_nc format with
+     * dimensions [2, Channels]. The first row holds the gamma (scale)
+     * parameters and the second row holds the beta (shift) parameters. */
+    mkldnn_memory_desc_t data_scaleshift_desc;
+    mkldnn_memory_desc_t diff_data_scaleshift_desc;
+    /** Mean and variance data memory descriptors.
+     *
+     * Mean and variance memory descriptors use the 1D #mkldnn_x format with
+     * dimension [Channels]. */
+    mkldnn_memory_desc_t mean_desc;
+    mkldnn_memory_desc_t variance_desc;
+    /** Batch normalization epsilon parameter. */
+    float batch_norm_epsilon;
+    unsigned flags;
+} mkldnn_batch_normalization_desc_t;
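Since the [2, Channels] layout above is easy to get backwards, a small editorial sketch (plain C, not part of the header) of filling an identity scale-shift buffer:

/* Row 0 holds gamma (scale), row 1 holds beta (shift), per the comment above. */
static void fill_identity_scaleshift(float *scaleshift, int channels) {
    for (int c = 0; c < channels; ++c) {
        scaleshift[0 * channels + c] = 1.0f; /* gamma for channel c */
        scaleshift[1 * channels + c] = 0.0f; /* beta for channel c */
    }
}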
+
+/** A descriptor of an inner product operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_inner_product. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, #mkldnn_backward_data,
+     * #mkldnn_backward_weights, and #mkldnn_backward_bias. */
+    mkldnn_prop_kind_t prop_kind;
+    /** Source memory descriptor. */
+    mkldnn_memory_desc_t src_desc;
+    /** Source gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_src_desc;
+    /** Weights memory descriptor. */
+    mkldnn_memory_desc_t weights_desc;
+    /** Weights gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_weights_desc;
+    /** Bias memory descriptor. */
+    mkldnn_memory_desc_t bias_desc;
+    /** Bias gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_bias_desc;
+    /** Destination memory descriptor. */
+    mkldnn_memory_desc_t dst_desc;
+    /** Destination gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_desc;
+    /** The accumulator data type. Initialized automatically. */
+    mkldnn_data_type_t accum_data_type;
+} mkldnn_inner_product_desc_t;
+
+/** Flags for RNN cell. */
+typedef enum {
+    mkldnn_rnn_cell_with_relu = 0x1U,
+    mkldnn_rnn_cell_with_clipping = 0x2U,
+} mkldnn_rnn_cell_flags_t;
+
+typedef struct {
+    /** RNN cell kind. Must be one of #mkldnn_vanilla_rnn,
+     * #mkldnn_vanilla_lstm, #mkldnn_vanilla_gru,
+     * or #mkldnn_gru_linear_before_reset. */
+    mkldnn_alg_kind_t cell_kind;
+    /** Activation function used. Must be either #mkldnn_eltwise_relu or
+     * #mkldnn_eltwise_tanh. */
+    mkldnn_alg_kind_t activation_kind;
+    /** RNN cell flags */
+    unsigned int flags;
+    /** @c alpha is a negative slope parameter (used only if
+     * `(flags & #mkldnn_rnn_cell_with_relu) != 0`) */
+    float alpha;
+    /** clipping parameter (used only if
+     * `(flags & #mkldnn_rnn_cell_with_clipping) != 0`) */
+    float clipping;
+} mkldnn_rnn_cell_desc_t;
+
+/** A direction of RNN primitive execution. */
+typedef enum {
+    /* Unidirectional execution of RNN primitive from left to right. */
+    mkldnn_unidirectional_left2right,
+    /* Unidirectional execution of RNN primitive from right to left. */
+    mkldnn_unidirectional_right2left,
+    /* Bidirectional execution of RNN primitive with concatenation of the
+     * results. */
+    mkldnn_bidirectional_concat,
+    /* Bidirectional execution of RNN primitive with summation of the
+     * results. */
+    mkldnn_bidirectional_sum,
+    mkldnn_unidirectional = mkldnn_unidirectional_left2right,
+} mkldnn_rnn_direction_t;
+
+/** A descriptor for an RNN operation. */
+typedef struct {
+    /** The kind of primitive. Used for self-identifying the primitive
+     * descriptor. Must be #mkldnn_rnn. */
+    mkldnn_primitive_kind_t primitive_kind;
+    /** The kind of propagation. Possible values: #mkldnn_forward_training,
+     * #mkldnn_forward_inference, and #mkldnn_backward. */
+    mkldnn_prop_kind_t prop_kind;
+    /** The RNN cell desc. */
+    mkldnn_rnn_cell_desc_t cell_desc;
+    /** The direction of RNN primitive execution. */
+    mkldnn_rnn_direction_t direction;
+    /** Source layer memory descriptor. */
+    mkldnn_memory_desc_t src_layer_desc;
+    /** Source iteration memory descriptor. */
+    mkldnn_memory_desc_t src_iter_desc;
+    /** Weights layer memory descriptor. */
+    mkldnn_memory_desc_t weights_layer_desc;
+    /** Weights iteration memory descriptor. */
+    mkldnn_memory_desc_t weights_iter_desc;
+    /** Bias memory descriptor. */
+    mkldnn_memory_desc_t bias_desc;
+    /** Destination layer memory descriptor. */
+    mkldnn_memory_desc_t dst_layer_desc;
+    /** Destination iter memory descriptor. */
+    mkldnn_memory_desc_t dst_iter_desc;
+    /** Source gradient layer memory descriptor. */
+    mkldnn_memory_desc_t diff_src_layer_desc;
+    /** Source gradient iter memory descriptor. */
+    mkldnn_memory_desc_t diff_src_iter_desc;
+    /** Weights gradient layer memory descriptor. */
+    mkldnn_memory_desc_t diff_weights_layer_desc;
+    /** Weights gradient iter memory descriptor. */
+    mkldnn_memory_desc_t diff_weights_iter_desc;
+    /** Bias gradient memory descriptor. */
+    mkldnn_memory_desc_t diff_bias_desc;
+    /** Destination gradient layer memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_layer_desc;
+    /** Destination gradient iteration memory descriptor. */
+    mkldnn_memory_desc_t diff_dst_iter_desc;
+} mkldnn_rnn_desc_t;
+
+/** @} */
+
+/** @addtogroup c_api_engine_types Engine
+ * @{ */
+
+/** @brief Kinds of engines. */
+typedef enum {
+    /** An unspecified engine. */
+    mkldnn_any_engine,
+    /** CPU engine. */
+    mkldnn_cpu,
+} mkldnn_engine_kind_t;
+
+/** @struct mkldnn_engine
+ * @brief An opaque structure to describe an engine. */
+struct mkldnn_engine;
+/** @brief An engine handle. */
+typedef struct mkldnn_engine *mkldnn_engine_t;
+#if 0
+/* FIXME: looks like this never happens */
+/** @brief A constant engine handle. */
+typedef const struct mkldnn_engine *const_mkldnn_engine_t;
+#endif
+
+/** @} */
+
+/** @addtogroup c_api_primitive_desc_iterators Primitive descriptor iterators
+ * @{ */
+
+/** @struct mkldnn_primitive_desc_iterator
+ * @brief An opaque structure to describe a primitive descriptor iterator. */
+struct mkldnn_primitive_desc_iterator;
+
+/** @brief A primitive descriptor iterator handle. */
+typedef struct mkldnn_primitive_desc_iterator
+    *mkldnn_primitive_desc_iterator_t;
+
+/** @brief A constant primitive descriptor iterator handle. */
+typedef const struct mkldnn_primitive_desc_iterator
+    *const_mkldnn_primitive_desc_iterator_t;
+
+/** @} */
+
+/** @addtogroup c_api_primitive_descs Primitive descriptors
+ * @{ */
+
+/** @struct mkldnn_primitive_desc
+ * @brief An opaque structure to describe a primitive descriptor. */
+struct mkldnn_primitive_desc;
+
+/** @brief A primitive descriptor handle. */
+typedef struct mkldnn_primitive_desc *mkldnn_primitive_desc_t;
+
+/** @brief A constant primitive descriptor handle. */
+typedef const struct mkldnn_primitive_desc *const_mkldnn_primitive_desc_t;
+
+/** @} */
+
+/** @addtogroup c_api_primitive_attr Primitive descriptor attributes
+ * @{ */
+
+/** Scratchpad mode */
+typedef enum {
+    /** The library manages scratchpad (default) */
+    mkldnn_scratchpad_mode_library,
+    /** The user must query and provide the scratchpad memory to primitives. */
+    mkldnn_scratchpad_mode_user,
+} mkldnn_scratchpad_mode_t;
+
+/** @struct mkldnn_primitive_attr
+ * @brief An opaque structure for primitive descriptor attributes.
+ *
+ * Attributes may contain:
+ *  - output scales (to scale the result prior to storing it to the memory)
+ */
+struct mkldnn_primitive_attr;
+
+/** @brief A primitive descriptor attributes handle that controls primitive
+ * behavior. */
+typedef struct mkldnn_primitive_attr *mkldnn_primitive_attr_t;
+
+/** @brief A constant primitive descriptor attributes handle. */
+typedef const struct mkldnn_primitive_attr *const_mkldnn_primitive_attr_t;
+
+/** @struct mkldnn_post_ops
+ * @brief An opaque structure for a chain of post operations.
+ *
+ * mkldnn_post_ops can be used to perform some (trivial) operations like
+ * accumulation or eltwise after certain primitives like convolution.
+ *
+ * Post operations can be combined into a chain. For instance, one can
+ * configure a convolution followed by accumulation followed by eltwise,
+ * which is especially beneficial for residual learning blocks.
+ *
+ * @warning
+ *      Not all combinations are supported, so the user should handle
+ *      errors accordingly.
+ *
+ * Supported post operations:
+ *  - accumulation (base primitive: convolution)
+ *  - eltwise (base primitive: convolution)
+ */
+struct mkldnn_post_ops;
+
+/** @brief A post operation chain handle. */
+typedef struct mkldnn_post_ops *mkldnn_post_ops_t;
+
+/** @brief A constant post operation chain handle. */
+typedef const struct mkldnn_post_ops *const_mkldnn_post_ops_t;
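As a hedged usage sketch (assuming the post-op and attribute helpers declared in mkldnn.h, e.g. mkldnn_post_ops_create() and mkldnn_primitive_attr_set_post_ops()), a residual-style chain of accumulation followed by ReLU might be set up like this; error handling is omitted for brevity:

mkldnn_post_ops_t ops;
mkldnn_primitive_attr_t attr;
mkldnn_post_ops_create(&ops);
mkldnn_post_ops_append_sum(ops, 1.0f);            /* dst += existing contents */
mkldnn_post_ops_append_eltwise(ops, 1.0f, mkldnn_eltwise_relu, 0.f, 0.f);
mkldnn_primitive_attr_create(&attr);
mkldnn_primitive_attr_set_post_ops(attr, ops);
/* ... pass `attr` when creating the convolution primitive descriptor ... */
mkldnn_post_ops_destroy(ops);
mkldnn_primitive_attr_destroy(attr);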
+
+/** @} */
+
+/** @addtogroup c_api_types_primitive Primitive
+ * @{ */
+
+/** @struct mkldnn_primitive
+ * An opaque structure to describe a primitive. */
+struct mkldnn_primitive;
+/** A primitive handle. */
+typedef struct mkldnn_primitive *mkldnn_primitive_t;
+/** A constant primitive handle. */
+typedef const struct mkldnn_primitive *const_mkldnn_primitive_t;
+
+/** @addtogroup c_api_types_arguments Argument indices
+ * @{ */
+
+#define MKLDNN_ARG_SRC_0                1
+#define MKLDNN_ARG_SRC                  MKLDNN_ARG_SRC_0
+#define MKLDNN_ARG_SRC_LAYER            MKLDNN_ARG_SRC_0
+#define MKLDNN_ARG_FROM                 MKLDNN_ARG_SRC_0
+
+#define MKLDNN_ARG_SRC_1                2
+#define MKLDNN_ARG_SRC_ITER             MKLDNN_ARG_SRC_1
+
+#define MKLDNN_ARG_DST_0                17
+#define MKLDNN_ARG_DST                  MKLDNN_ARG_DST_0
+#define MKLDNN_ARG_TO                   MKLDNN_ARG_DST_0
+#define MKLDNN_ARG_DST_LAYER            MKLDNN_ARG_DST_0
+
+#define MKLDNN_ARG_DST_1                18
+#define MKLDNN_ARG_DST_ITER             MKLDNN_ARG_DST_1
+
+#define MKLDNN_ARG_WEIGHTS_0            33
+#define MKLDNN_ARG_WEIGHTS              MKLDNN_ARG_WEIGHTS_0
+#define MKLDNN_ARG_SCALE_SHIFT          MKLDNN_ARG_WEIGHTS_0
+#define MKLDNN_ARG_WEIGHTS_LAYER        MKLDNN_ARG_WEIGHTS_0
+
+#define MKLDNN_ARG_WEIGHTS_1            34
+#define MKLDNN_ARG_WEIGHTS_ITER         MKLDNN_ARG_WEIGHTS_1
+
+#define MKLDNN_ARG_BIAS                 41
+
+#define MKLDNN_ARG_MEAN                 49
+#define MKLDNN_ARG_VARIANCE             50
+
+#define MKLDNN_ARG_WORKSPACE            64
+#define MKLDNN_ARG_SCRATCHPAD           80
+
+#define MKLDNN_ARG_DIFF_SRC_0           129
+#define MKLDNN_ARG_DIFF_SRC             MKLDNN_ARG_DIFF_SRC_0
+#define MKLDNN_ARG_DIFF_SRC_LAYER       MKLDNN_ARG_DIFF_SRC_0
+
+#define MKLDNN_ARG_DIFF_SRC_1           130
+#define MKLDNN_ARG_DIFF_SRC_ITER        MKLDNN_ARG_DIFF_SRC_1
+
+#define MKLDNN_ARG_DIFF_DST_0           145
+#define MKLDNN_ARG_DIFF_DST             MKLDNN_ARG_DIFF_DST_0
+#define MKLDNN_ARG_DIFF_DST_LAYER       MKLDNN_ARG_DIFF_DST_0
+
+#define MKLDNN_ARG_DIFF_DST_1           146
+#define MKLDNN_ARG_DIFF_DST_ITER        MKLDNN_ARG_DIFF_DST_1
+
+#define MKLDNN_ARG_DIFF_WEIGHTS_0       161
+#define MKLDNN_ARG_DIFF_WEIGHTS         MKLDNN_ARG_DIFF_WEIGHTS_0
+#define MKLDNN_ARG_DIFF_SCALE_SHIFT     MKLDNN_ARG_DIFF_WEIGHTS_0
+#define MKLDNN_ARG_DIFF_WEIGHTS_LAYER   MKLDNN_ARG_DIFF_WEIGHTS_0
+
+#define MKLDNN_ARG_DIFF_WEIGHTS_1       162
+#define MKLDNN_ARG_DIFF_WEIGHTS_ITER    MKLDNN_ARG_DIFF_WEIGHTS_1
+
+#define MKLDNN_ARG_DIFF_BIAS            169
+
+#define MKLDNN_ARG_MULTIPLE_SRC         1024
+#define MKLDNN_ARG_MULTIPLE_DST         2048
+
+/** @} */
+
+/** An auxiliary structure to specify a primitive's inputs/outputs at
+ * execution time.
+ *
+ * @warning
+ *      With this API it is impossible to preserve the constness of memory, so
+ *      all memories are passed without a const qualifier. However, only
+ *      memories with output semantics may be changed during execution. */
+typedef struct {
+    int arg; /**< An argument index, e.g. MKLDNN_ARG_SRC */
+    mkldnn_memory_t memory; /**< Input/output memory */
+} mkldnn_exec_arg_t;
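An editorial sketch of an argument list for a forward convolution with bias, using the indices above; the memory handles (src_mem, weights_mem, bias_mem, dst_mem) are assumed to have been created elsewhere, and the execution entry point (mkldnn_primitive_execute() in mkldnn.h) is an assumption of this sketch:

mkldnn_exec_arg_t args[] = {
    { MKLDNN_ARG_SRC,     src_mem },
    { MKLDNN_ARG_WEIGHTS, weights_mem },
    { MKLDNN_ARG_BIAS,    bias_mem },
    { MKLDNN_ARG_DST,     dst_mem },
};
/* e.g. mkldnn_primitive_execute(conv, stream, 4, args); */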
+
+/** @} */
+
+/** @addtogroup c_api_types_query Queries
+ * @{ */
+
+/** Primitive descriptor query specification
+ *
+ * For the generic function mkldnn_primitive_desc_query(), the type of the
+ * result must agree with the queried argument, as shown in the correspondence
+ * table:
+ *      Query                           | type of result
+ *      --------------------------------------------------------------
+ *      #mkldnn_query_engine            | mkldnn_engine_t *
+ *      #mkldnn_query_scratchpad_engine | mkldnn_engine_t *
+ *      #mkldnn_query_primitive_kind    | mkldnn_primitive_kind_t *
+ *      *_s32                           | int *
+ *      *_s64                           | mkldnn_dim_t * (same as int64_t *)
+ *      *_f64                           | double *
+ *      *_str                           | const char **
+ *      #mkldnn_query_op_d              | const_mkldnn_op_desc_t *
+ *      *_md                            | const mkldnn_memory_desc_t **
+ *      *_${op}_d                       | const mkldnn_${op}_desc_t **
+ *      *_pd                            | const_mkldnn_primitive_desc_t *
+ *
+ * @note
+ *     Rule of thumb: all opaque types and structures are returned by
+ *     reference. All numbers are returned by value.
+ *
+ * @warning
+ *     All returned references point to constant objects and are valid only
+ *     during the lifetime of the queried primitive descriptor. Returned objects
+ *     must not be destroyed by the user. If you need to keep the object longer
+ *     than the lifetime of the queried primitive descriptor, use
+ *     mkldnn_primitive_desc_clone() to make a copy. */
+typedef enum {
+    mkldnn_query_undef = 0,  /**< no query */
+
+    mkldnn_query_engine, /**< execution engine */
+    mkldnn_query_primitive_kind, /**< primitive kind */
+
+    mkldnn_query_num_of_inputs_s32, /**< number of inputs expected */
+    mkldnn_query_num_of_outputs_s32, /**< number of outputs expected */
+
+    mkldnn_query_time_estimate_f64, /**< runtime estimation (seconds) */
+    mkldnn_query_memory_consumption_s64, /**< memory consumption -- extra
+                                           (scratch) memory in addition to the
+                                           memory of all inputs and outputs,
+                                           in bytes */
+
+    mkldnn_query_scratchpad_engine, /**< scratchpad engine -- engine to be used
+                                       for creating scratchpad memory */
+
+    mkldnn_query_impl_info_str, /**< implementation name */
+
+    /* memory and op descriptor section */
+    mkldnn_query_some_d = 64, /**< stub */
+    mkldnn_query_op_d, /**< op descriptor */
+    mkldnn_query_convolution_d, /**< convolution descriptor */
+    mkldnn_query_deconvolution_d, /**< deconvolution descriptor */
+    mkldnn_query_shuffle_d, /**< shuffle descriptor */
+    mkldnn_query_eltwise_d, /**< eltwise descriptor */
+    mkldnn_query_softmax_d, /**< softmax descriptor */
+    mkldnn_query_pooling_d, /**< pooling descriptor */
+    mkldnn_query_lrn_d, /**< lrn descriptor */
+    mkldnn_query_batch_normalization_d, /**< batch normalization descriptor */
+    mkldnn_query_inner_product_d, /**< inner product descriptor */
+    mkldnn_query_rnn_d, /**< rnn descriptor */
+
+    /* memory descriptor section */
+    mkldnn_query_some_md = 128, /**< stub */
+    mkldnn_query_src_md, /**< source memory desc */
+    mkldnn_query_diff_src_md, /**< source gradient memory desc */
+    mkldnn_query_weights_md, /**< weights memory desc */
+    mkldnn_query_diff_weights_md, /**< weights grad. memory desc */
+    mkldnn_query_dst_md, /**< destination memory desc */
+    mkldnn_query_diff_dst_md, /**< destination grad. memory desc */
+    mkldnn_query_workspace_md, /**< workspace memory desc */
+    mkldnn_query_scratchpad_md, /**< scratchpad memory desc */
+} mkldnn_query_t;
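A hedged editorial sketch of the query convention above, assuming mkldnn_primitive_desc_query() from mkldnn.h and an existing primitive descriptor `pd`; note how the result type follows the table:

mkldnn_dim_t scratch_bytes = 0;
const char *impl_name = NULL;
mkldnn_primitive_desc_query(pd, mkldnn_query_memory_consumption_s64, 0,
        &scratch_bytes);                  /* *_s64 -> mkldnn_dim_t * */
mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0,
        &impl_name);                      /* *_str -> const char ** */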
+
+/** @} */
+
+/** @addtogroup c_api_types_stream Execution stream
+ * @{ */
+
+/** @brief Stream flags. */
+typedef enum {
+    /** A default stream configuration. */
+    mkldnn_stream_default_flags = 0x0U,
+} mkldnn_stream_flags_t;
+
+/** @struct mkldnn_stream
+ * An opaque structure to describe an execution stream. */
+struct mkldnn_stream;
+/** An execution stream handle. */
+typedef struct mkldnn_stream *mkldnn_stream_t;
+/** A constant execution stream handle. */
+typedef const struct mkldnn_stream *const_mkldnn_stream_t;
+
+/** @} */
+/** @} */
+/** @} */
+
+#ifdef __cplusplus
+}
+#endif
+
+
+#endif

+ 32 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h

@@ -0,0 +1,32 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_VERSION_H
+#define MKLDNN_VERSION_H
+
+/* Major version of MKL-DNN */
+#define MKLDNN_VERSION_MAJOR 0
+
+/* Minor version of MKL-DNN */
+#define MKLDNN_VERSION_MINOR 90
+
+/* Patch version of MKL-DNN */
+#define MKLDNN_VERSION_PATCH 0
+
+/* Git Commit Hash of MKL-DNN */
+#define MKLDNN_VERSION_HASH  "096bda1ca23324879f2df5a129e610e4405f775c"
+
+#endif
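These macros lend themselves to a compile-time guard; a minimal editorial sketch (the 0.90 threshold is arbitrary, for illustration only):

#if (MKLDNN_VERSION_MAJOR == 0) && (MKLDNN_VERSION_MINOR < 90)
#error "MKL-DNN 0.90 or newer expected"
#endif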

+ 32 - 0
thirdparty/oidn/mkl-dnn/include/mkldnn_version.h.in

@@ -0,0 +1,32 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_VERSION_H
+#define MKLDNN_VERSION_H
+
+/* Major version of MKL-DNN */
+#define MKLDNN_VERSION_MAJOR @MKLDNN_VERSION_MAJOR@
+
+/* Minor version of MKL-DNN */
+#define MKLDNN_VERSION_MINOR @MKLDNN_VERSION_MINOR@
+
+/* Patch version of MKL-DNN */
+#define MKLDNN_VERSION_PATCH @MKLDNN_VERSION_PATCH@
+
+/* Git Commit Hash of MKL-DNN */
+#define MKLDNN_VERSION_HASH  "@MKLDNN_VERSION_HASH@"
+
+#endif

+ 104 - 0
thirdparty/oidn/mkl-dnn/src/common/batch_normalization.cpp

@@ -0,0 +1,104 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t bnrm_desc_init(batch_normalization_desc_t *bnrm_desc,
+        prop_kind_t prop_kind, const memory_desc_t *data_desc,
+        const memory_desc_t *diff_data_desc, float epsilon, unsigned flags) {
+    bool args_ok = true
+        && !any_null(bnrm_desc, data_desc)
+        && one_of(prop_kind, forward_training, forward_inference,
+                backward_data, backward)
+        && IMPLICATION(prop_kind & backward, diff_data_desc != nullptr);
+    if (!args_ok) return invalid_arguments;
+
+    auto bd = batch_normalization_desc_t();
+    bd.primitive_kind = primitive_kind::batch_normalization;
+    bd.prop_kind = prop_kind;
+
+    bd.data_desc = *data_desc;
+    bd.diff_data_desc = zero_md();
+    if (one_of(bd.prop_kind, backward_data, backward))
+        bd.diff_data_desc = *diff_data_desc;
+
+    dims_t scaleshift_dims = { 2, data_desc->dims[1] };
+    mkldnn_memory_desc_init_by_tag(&bd.data_scaleshift_desc, 2,
+            scaleshift_dims, data_type::f32, mkldnn_nc);
+    bd.diff_data_scaleshift_desc = zero_md();
+    if (bd.prop_kind == backward) {
+        bd.diff_data_scaleshift_desc = bd.data_scaleshift_desc;
+    }
+
+    dims_t stats_dims = { data_desc->dims[1] };
+    mkldnn_memory_desc_init_by_tag(&bd.mean_desc, 1, stats_dims,
+            data_type::f32, mkldnn_x);
+    bd.variance_desc = bd.mean_desc;
+    bd.batch_norm_epsilon = epsilon;
+
+    unsigned bnorm_flags =
+        mkldnn_use_global_stats | mkldnn_use_scaleshift | mkldnn_fuse_bn_relu;
+    if ((~bnorm_flags & flags) != 0) return invalid_arguments;
+
+    bd.flags = flags;
+
+    bool consistency = true
+        && utils::one_of(bd.data_desc.ndims, 2, 4, 5);
+    if (bd.prop_kind == backward_data)
+        consistency = consistency
+            && utils::one_of(bd.diff_data_desc.ndims, 2, 4, 5)
+            && array_cmp(bd.diff_data_desc.dims, bd.data_desc.dims,
+                    bd.diff_data_desc.ndims);
+    if (!consistency) return invalid_arguments;
+
+    *bnrm_desc = bd;
+    return success;
+}
+}
+
+status_t mkldnn_batch_normalization_forward_desc_init(
+        batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+        const memory_desc_t *data_desc, float epsilon, unsigned flags) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, nullptr,
+            epsilon, flags);
+}
+
+status_t mkldnn_batch_normalization_backward_desc_init(
+        batch_normalization_desc_t *bnrm_desc, prop_kind_t prop_kind,
+        const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
+        float epsilon, unsigned flags) {
+    if (!one_of(prop_kind, backward, backward_data))
+        return invalid_arguments;
+    return bnrm_desc_init(bnrm_desc, prop_kind, data_desc, diff_data_desc,
+            epsilon, flags);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
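A hedged editorial sketch of calling the forward entry point defined above for inference with precomputed statistics and a scale-shift; `data_md` is assumed to be a memory descriptor initialized elsewhere, and the flag and status names come from the C headers above:

mkldnn_batch_normalization_desc_t bnrm;
mkldnn_status_t st = mkldnn_batch_normalization_forward_desc_init(&bnrm,
        mkldnn_forward_inference, &data_md, 1e-5f,
        mkldnn_use_global_stats | mkldnn_use_scaleshift);
if (st != mkldnn_success) {
    /* e.g. an unsupported flag or propagation kind was passed */
}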

+ 240 - 0
thirdparty/oidn/mkl-dnn/src/common/batch_normalization_pd.hpp

@@ -0,0 +1,240 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef BATCH_NORMALIZATION_PD_HPP
+#define BATCH_NORMALIZATION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct batch_normalization_fwd_pd_t;
+
+struct batch_normalization_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::batch_normalization;
+
+    batch_normalization_pd_t(engine_t *engine,
+            const batch_normalization_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const batch_normalization_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , data_md_(desc_.data_desc)
+        , stat_md_(desc_.mean_desc)
+        , scaleshift_md_(desc_.data_scaleshift_desc)
+        , ws_md_()
+    {}
+
+    const batch_normalization_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::batch_normalization_d:
+            *(const batch_normalization_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common batch_normalization aux functions */
+
+    dim_t MB() const { return data_desc().dims[0]; }
+    dim_t C() const { return data_desc().dims[1]; }
+    dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+    dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+    dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+    int ndims() const { return desc_.data_desc.ndims; }
+
+    bool stats_is_src() const { return desc_.flags & mkldnn_use_global_stats; }
+    bool use_scaleshift() const { return desc_.flags & mkldnn_use_scaleshift; }
+    bool use_global_stats() const
+    { return desc_.flags & mkldnn_use_global_stats; }
+    bool fuse_bn_relu() const { return desc_.flags & mkldnn_fuse_bn_relu; }
+    bool with_relu_post_op() const {
+        const auto &p = this->attr()->post_ops_;
+        return p.len_ == 1 && p.entry_[0].is_relu(true, true);
+    }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+    bool is_bwd() const { return !this->is_fwd(); }
+    bool is_training() const
+    { return desc_.prop_kind == prop_kind::forward_training; }
+
+    bool has_zero_dim_memory() const
+    { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+protected:
+    batch_normalization_desc_t desc_;
+    const batch_normalization_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t data_md_;
+    memory_desc_t stat_md_;
+    memory_desc_t scaleshift_md_;
+
+    memory_desc_t ws_md_;
+
+    void init_default_ws(size_t bits_per_element) {
+        const auto data_mdw = memory_desc_wrapper(data_md_);
+
+        const dim_t data_nelems = data_mdw.nelems(true);
+        const dim_t bits_per_byte = 8;
+        const dims_t ws_sz = { (dim_t)utils::div_up(
+                data_nelems * bits_per_element, bits_per_byte) };
+        mkldnn_memory_desc_init_by_tag(&ws_md_, 1, ws_sz, impl::data_type::u8,
+                format_tag::x);
+    }
+
+private:
+    const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct batch_normalization_fwd_pd_t: public batch_normalization_pd_t {
+    typedef batch_normalization_fwd_pd_t base_class;
+    typedef batch_normalization_fwd_pd_t hint_class;
+
+    batch_normalization_fwd_pd_t(engine_t *engine,
+            const batch_normalization_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const batch_normalization_fwd_pd_t *hint_fwd_pd)
+        : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC) return arg_usage_t::input;
+        if (arg == MKLDNN_ARG_DST) return arg_usage_t::output;
+
+        if (utils::one_of(arg, MKLDNN_ARG_MEAN, MKLDNN_ARG_VARIANCE)) {
+            if (stats_is_src()) return arg_usage_t::input;
+            if (!stats_is_src() && is_training()) return arg_usage_t::output;
+            return arg_usage_t::unused;
+        }
+
+        if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && is_training() && fuse_bn_relu())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override {
+        if (index == 0) return &data_md_;
+        if (stats_is_src() && (index == 1 || index == 2)) return &stat_md_;
+        return nullptr;
+    }
+
+    virtual const memory_desc_t *dst_md(int index = 0) const override {
+        if (index == 0) return &data_md_;
+        if (!stats_is_src() && is_training() && (index == 1 || index == 2))
+            return &stat_md_;
+        return nullptr;
+    }
+
+    virtual const memory_desc_t *weights_md(int index = 0) const override
+    { return index == 0 ? &scaleshift_md_ : nullptr; }
+
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && is_training() && fuse_bn_relu() ? &ws_md_ : nullptr; }
+
+    const memory_desc_t *stat_md() const
+    { return stats_is_src() ? src_md(1) : dst_md(1); }
+
+    virtual int n_inputs() const override
+    { return 1 + 2 * stats_is_src() + use_scaleshift(); }
+    virtual int n_outputs() const override
+    { return 1 + (fuse_bn_relu() + 2 * (!stats_is_src())) * is_training(); }
+};
+
+struct batch_normalization_bwd_pd_t: public batch_normalization_pd_t {
+    typedef batch_normalization_bwd_pd_t base_class;
+    typedef batch_normalization_fwd_pd_t hint_class;
+
+    batch_normalization_bwd_pd_t(engine_t *engine,
+            const batch_normalization_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const batch_normalization_fwd_pd_t *hint_fwd_pd)
+        : batch_normalization_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_data_md_(desc_.diff_data_desc)
+        , diff_scaleshift_md_(desc_.diff_data_scaleshift_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_MEAN,
+                    MKLDNN_ARG_VARIANCE, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_SCALE_SHIFT && use_scaleshift())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && fuse_bn_relu())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_DIFF_SCALE_SHIFT && use_scaleshift())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : index <= 2 ? &stat_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+
+    virtual const memory_desc_t *weights_md(int index = 0) const override
+    { return index == 0 ? &scaleshift_md_ : nullptr; }
+    virtual const memory_desc_t *diff_weights_md(int index = 0) const override
+    { return index == 0 ? &diff_scaleshift_md_ : nullptr; }
+
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && fuse_bn_relu() ? &ws_md_ : nullptr; }
+
+    const memory_desc_t *stat_md() const { return src_md(1); }
+
+    virtual int n_inputs() const override
+    { return 4 + use_scaleshift() + fuse_bn_relu(); }
+    virtual int n_outputs() const override
+    { return 1 + (desc_.prop_kind == prop_kind::backward); }
+
+protected:
+    memory_desc_t diff_data_md_;
+    memory_desc_t diff_scaleshift_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 550 - 0
thirdparty/oidn/mkl-dnn/src/common/c_types_map.hpp

@@ -0,0 +1,550 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef TYPE_MAPPING_HPP
+#define TYPE_MAPPING_HPP
+
+#include "mkldnn_types.h"
+
+namespace mkldnn {
+namespace impl {
+
+// TODO: autogenerate this
+
+using dim_t = mkldnn_dim_t;
+using dims_t = mkldnn_dims_t;
+using stride_t = mkldnn_dim_t;
+using strides_t = mkldnn_strides_t;
+
+using status_t = mkldnn_status_t;
+namespace status {
+    const status_t success = mkldnn_success;
+    const status_t out_of_memory = mkldnn_out_of_memory;
+    const status_t try_again = mkldnn_try_again;
+    const status_t invalid_arguments = mkldnn_invalid_arguments;
+    const status_t not_ready = mkldnn_not_ready;
+    const status_t unimplemented = mkldnn_unimplemented;
+    const status_t iterator_ends = mkldnn_iterator_ends;
+    const status_t runtime_error = mkldnn_runtime_error;
+    const status_t not_required = mkldnn_not_required;
+}
+
+using prop_kind_t = mkldnn_prop_kind_t;
+namespace prop_kind {
+    const prop_kind_t undef = mkldnn_prop_kind_undef;
+    const prop_kind_t forward_training = mkldnn_forward_training;
+    const prop_kind_t forward_inference = mkldnn_forward_inference;
+    const prop_kind_t forward_scoring = mkldnn_forward_scoring;
+    const prop_kind_t forward = mkldnn_forward;
+    const prop_kind_t backward = mkldnn_backward;
+    const prop_kind_t backward_data = mkldnn_backward_data;
+    const prop_kind_t backward_weights = mkldnn_backward_weights;
+    const prop_kind_t backward_bias = mkldnn_backward_bias;
+}
+
+using alg_kind_t = mkldnn_alg_kind_t;
+namespace alg_kind {
+    const alg_kind_t undef = mkldnn_alg_kind_undef;
+    const alg_kind_t convolution_auto = mkldnn_convolution_auto;
+    const alg_kind_t convolution_direct = mkldnn_convolution_direct;
+    const alg_kind_t convolution_winograd = mkldnn_convolution_winograd;
+    const alg_kind_t deconvolution_direct = mkldnn_deconvolution_direct;
+    const alg_kind_t deconvolution_winograd = mkldnn_deconvolution_winograd;
+    const alg_kind_t eltwise_relu = mkldnn_eltwise_relu;
+    const alg_kind_t eltwise_tanh = mkldnn_eltwise_tanh;
+    const alg_kind_t eltwise_elu = mkldnn_eltwise_elu;
+    const alg_kind_t eltwise_square = mkldnn_eltwise_square;
+    const alg_kind_t eltwise_abs = mkldnn_eltwise_abs;
+    const alg_kind_t eltwise_sqrt = mkldnn_eltwise_sqrt;
+    const alg_kind_t eltwise_linear = mkldnn_eltwise_linear;
+    const alg_kind_t eltwise_bounded_relu = mkldnn_eltwise_bounded_relu;
+    const alg_kind_t eltwise_soft_relu = mkldnn_eltwise_soft_relu;
+    const alg_kind_t eltwise_logistic = mkldnn_eltwise_logistic;
+    const alg_kind_t pooling_max = mkldnn_pooling_max;
+    const alg_kind_t pooling_avg = mkldnn_pooling_avg;
+    const alg_kind_t pooling_avg_include_padding = mkldnn_pooling_avg_include_padding;
+    const alg_kind_t pooling_avg_exclude_padding = mkldnn_pooling_avg_exclude_padding;
+    const alg_kind_t lrn_across_channels = mkldnn_lrn_across_channels;
+    const alg_kind_t lrn_within_channel = mkldnn_lrn_within_channel;
+    const alg_kind_t vanilla_rnn = mkldnn_vanilla_rnn;
+    const alg_kind_t vanilla_lstm = mkldnn_vanilla_lstm;
+    const alg_kind_t vanilla_gru = mkldnn_vanilla_gru;
+    const alg_kind_t gru_linear_before_reset = mkldnn_gru_linear_before_reset;
+}
+
+using data_type_t = mkldnn_data_type_t;
+namespace data_type {
+    const data_type_t undef = mkldnn_data_type_undef;
+    const data_type_t f32 = mkldnn_f32;
+    const data_type_t s32 = mkldnn_s32;
+    const data_type_t s8 = mkldnn_s8;
+    const data_type_t u8 = mkldnn_u8;
+}
+
+using scratchpad_mode_t = mkldnn_scratchpad_mode_t;
+namespace scratchpad_mode {
+    const scratchpad_mode_t library = mkldnn_scratchpad_mode_library;
+    const scratchpad_mode_t user = mkldnn_scratchpad_mode_user;
+}
+
+using rnn_packed_format_t = mkldnn_rnn_packed_memory_format_t;
+namespace rnn_packed_format {
+    const rnn_packed_format_t undef = mkldnn_packed_format_undef;
+    const rnn_packed_format_t ldigo_p = mkldnn_ldigo_p;
+    const rnn_packed_format_t ldgoi_p = mkldnn_ldgoi_p;
+}
+
+using format_kind_t = mkldnn_format_kind_t;
+namespace format_kind {
+    const format_kind_t undef = mkldnn_format_kind_undef;
+    const format_kind_t any = mkldnn_format_kind_any;
+    const format_kind_t blocked = mkldnn_blocked;
+    const format_kind_t wino = mkldnn_format_kind_wino;
+    const format_kind_t rnn_packed = mkldnn_format_kind_rnn_packed;
+}
+
+using format_tag_t = mkldnn_format_tag_t;
+namespace format_tag {
+    const format_tag_t undef = mkldnn_format_tag_undef;
+    const format_tag_t any = mkldnn_format_tag_any;
+    const format_tag_t a = mkldnn_a;
+    const format_tag_t ab = mkldnn_ab;
+    const format_tag_t abc = mkldnn_abc;
+    const format_tag_t abcd = mkldnn_abcd;
+    const format_tag_t abcde = mkldnn_abcde;
+    const format_tag_t abcdef = mkldnn_abcdef;
+    const format_tag_t abdec = mkldnn_abdec;
+    const format_tag_t acb = mkldnn_acb;
+    const format_tag_t acbde = mkldnn_acbde;
+    const format_tag_t acdb = mkldnn_acdb;
+    const format_tag_t acdeb = mkldnn_acdeb;
+    const format_tag_t ba = mkldnn_ba;
+    const format_tag_t bac = mkldnn_bac;
+    const format_tag_t bacd = mkldnn_bacd;
+    const format_tag_t bcda = mkldnn_bcda;
+    const format_tag_t cba = mkldnn_cba;
+    const format_tag_t cdba = mkldnn_cdba;
+    const format_tag_t cdeba = mkldnn_cdeba;
+    const format_tag_t decab = mkldnn_decab;
+    const format_tag_t Abc16a = mkldnn_Abc16a;
+    const format_tag_t ABc16a16b = mkldnn_ABc16a16b;
+    const format_tag_t aBc16b = mkldnn_aBc16b;
+    const format_tag_t ABc16b16a = mkldnn_ABc16b16a;
+    const format_tag_t Abc4a = mkldnn_Abc4a;
+    const format_tag_t aBc4b = mkldnn_aBc4b;
+    const format_tag_t ABc4b16a4b = mkldnn_ABc4b16a4b;
+    const format_tag_t ABc4b4a = mkldnn_ABc4b4a;
+    const format_tag_t ABc8a16b2a = mkldnn_ABc8a16b2a;
+    const format_tag_t ABc8a8b = mkldnn_ABc8a8b;
+    const format_tag_t aBc8b = mkldnn_aBc8b;
+    const format_tag_t ABc8b16a2b = mkldnn_ABc8b16a2b;
+    const format_tag_t ABc8b8a = mkldnn_ABc8b8a;
+    const format_tag_t Abcd16a = mkldnn_Abcd16a;
+    const format_tag_t ABcd16a16b = mkldnn_ABcd16a16b;
+    const format_tag_t aBcd16b = mkldnn_aBcd16b;
+    const format_tag_t ABcd16b16a = mkldnn_ABcd16b16a;
+    const format_tag_t aBCd16b16c = mkldnn_aBCd16b16c;
+    const format_tag_t aBCd16c16b = mkldnn_aBCd16c16b;
+    const format_tag_t Abcd4a = mkldnn_Abcd4a;
+    const format_tag_t aBcd4b = mkldnn_aBcd4b;
+    const format_tag_t ABcd4b16a4b = mkldnn_ABcd4b16a4b;
+    const format_tag_t ABcd4b4a = mkldnn_ABcd4b4a;
+    const format_tag_t aBCd4c16b4c = mkldnn_aBCd4c16b4c;
+    const format_tag_t aBCd4c4b = mkldnn_aBCd4c4b;
+    const format_tag_t ABcd8a16b2a = mkldnn_ABcd8a16b2a;
+    const format_tag_t ABcd8a8b = mkldnn_ABcd8a8b;
+    const format_tag_t aBcd8b = mkldnn_aBcd8b;
+    const format_tag_t ABcd8b16a2b = mkldnn_ABcd8b16a2b;
+    const format_tag_t aBCd8b16c2b = mkldnn_aBCd8b16c2b;
+    const format_tag_t ABcd8b8a = mkldnn_ABcd8b8a;
+    const format_tag_t aBCd8b8c = mkldnn_aBCd8b8c;
+    const format_tag_t aBCd8c16b2c = mkldnn_aBCd8c16b2c;
+    const format_tag_t aBCd8c8b = mkldnn_aBCd8c8b;
+    const format_tag_t Abcde16a = mkldnn_Abcde16a;
+    const format_tag_t ABcde16a16b = mkldnn_ABcde16a16b;
+    const format_tag_t aBcde16b = mkldnn_aBcde16b;
+    const format_tag_t ABcde16b16a = mkldnn_ABcde16b16a;
+    const format_tag_t aBCde16b16c = mkldnn_aBCde16b16c;
+    const format_tag_t aBCde16c16b = mkldnn_aBCde16c16b;
+    const format_tag_t aBCde2c8b4c = mkldnn_aBCde2c8b4c;
+    const format_tag_t Abcde4a = mkldnn_Abcde4a;
+    const format_tag_t aBcde4b = mkldnn_aBcde4b;
+    const format_tag_t ABcde4b4a = mkldnn_ABcde4b4a;
+    const format_tag_t aBCde4b4c = mkldnn_aBCde4b4c;
+    const format_tag_t aBCde4c16b4c = mkldnn_aBCde4c16b4c;
+    const format_tag_t aBCde4c4b = mkldnn_aBCde4c4b;
+    const format_tag_t Abcde8a = mkldnn_Abcde8a;
+    const format_tag_t ABcde8a8b = mkldnn_ABcde8a8b;
+    const format_tag_t aBcde8b = mkldnn_aBcde8b;
+    const format_tag_t ABcde8b16a2b = mkldnn_ABcde8b16a2b;
+    const format_tag_t aBCde8b16c2b = mkldnn_aBCde8b16c2b;
+    const format_tag_t ABcde8b8a = mkldnn_ABcde8b8a;
+    const format_tag_t aBCde8b8c = mkldnn_aBCde8b8c;
+    const format_tag_t aBCde8c16b2c = mkldnn_aBCde8c16b2c;
+    const format_tag_t aBCde8c8b = mkldnn_aBCde8c8b;
+    const format_tag_t aBcdef16b = mkldnn_aBcdef16b;
+    const format_tag_t aBCdef16b16c = mkldnn_aBCdef16b16c;
+    const format_tag_t aBCdef16c16b = mkldnn_aBCdef16c16b;
+    const format_tag_t aBcdef4b = mkldnn_aBcdef4b;
+    const format_tag_t aBCdef4c4b = mkldnn_aBCdef4c4b;
+    const format_tag_t aBCdef8b8c = mkldnn_aBCdef8b8c;
+    const format_tag_t aBCdef8c16b2c = mkldnn_aBCdef8c16b2c;
+    const format_tag_t aBCdef8c8b = mkldnn_aBCdef8c8b;
+    const format_tag_t aBdc16b = mkldnn_aBdc16b;
+    const format_tag_t aBdc4b = mkldnn_aBdc4b;
+    const format_tag_t aBdc8b = mkldnn_aBdc8b;
+    const format_tag_t aBdec16b = mkldnn_aBdec16b;
+    const format_tag_t aBdec4b = mkldnn_aBdec4b;
+    const format_tag_t aBdec8b = mkldnn_aBdec8b;
+    const format_tag_t aBdefc16b = mkldnn_aBdefc16b;
+    const format_tag_t aBdefc4b = mkldnn_aBdefc4b;
+    const format_tag_t aBdefc8b = mkldnn_aBdefc8b;
+    const format_tag_t Acb16a = mkldnn_Acb16a;
+    const format_tag_t Acb4a = mkldnn_Acb4a;
+    const format_tag_t Acb8a = mkldnn_Acb8a;
+    const format_tag_t aCBd16b16c = mkldnn_aCBd16b16c;
+    const format_tag_t aCBde16b16c = mkldnn_aCBde16b16c;
+    const format_tag_t Acdb16a = mkldnn_Acdb16a;
+    const format_tag_t Acdb4a = mkldnn_Acdb4a;
+    const format_tag_t Acdb8a = mkldnn_Acdb8a;
+    const format_tag_t Acdeb16a = mkldnn_Acdeb16a;
+    const format_tag_t Acdeb4a = mkldnn_Acdeb4a;
+    const format_tag_t Acdeb8a = mkldnn_Acdeb8a;
+    const format_tag_t BAc16a16b = mkldnn_BAc16a16b;
+    const format_tag_t BAcd16a16b = mkldnn_BAcd16a16b;
+    const format_tag_t last = mkldnn_format_tag_last;
+
+    const format_tag_t x = mkldnn_x;
+    const format_tag_t nc = mkldnn_nc;
+    const format_tag_t cn = mkldnn_cn;
+    const format_tag_t ncw = mkldnn_ncw;
+    const format_tag_t nwc = mkldnn_nwc;
+    const format_tag_t nchw = mkldnn_nchw;
+    const format_tag_t nhwc = mkldnn_nhwc;
+    const format_tag_t chwn = mkldnn_chwn;
+    const format_tag_t ncdhw = mkldnn_ncdhw;
+    const format_tag_t ndhwc = mkldnn_ndhwc;
+    const format_tag_t oi = mkldnn_oi;
+    const format_tag_t io = mkldnn_io;
+    const format_tag_t oiw = mkldnn_oiw;
+    const format_tag_t wio = mkldnn_wio;
+    const format_tag_t oihw = mkldnn_oihw;
+    const format_tag_t hwio = mkldnn_hwio;
+    const format_tag_t ihwo = mkldnn_ihwo;
+    const format_tag_t iohw = mkldnn_iohw;
+    const format_tag_t oidhw = mkldnn_oidhw;
+    const format_tag_t dhwio = mkldnn_dhwio;
+    const format_tag_t goiw = mkldnn_goiw;
+    const format_tag_t goihw = mkldnn_goihw;
+    const format_tag_t hwigo = mkldnn_hwigo;
+    const format_tag_t giohw = mkldnn_giohw;
+    const format_tag_t goidhw = mkldnn_goidhw;
+    const format_tag_t tnc = mkldnn_tnc;
+    const format_tag_t ntc = mkldnn_ntc;
+    const format_tag_t ldsnc = mkldnn_ldsnc;
+    const format_tag_t ldigo = mkldnn_ldigo;
+    const format_tag_t ldgoi = mkldnn_ldgoi;
+    const format_tag_t ldgo = mkldnn_ldgo;
+    const format_tag_t nCdhw16c = mkldnn_nCdhw16c;
+    const format_tag_t nCdhw4c = mkldnn_nCdhw4c;
+    const format_tag_t nCdhw8c = mkldnn_nCdhw8c;
+    const format_tag_t nChw16c = mkldnn_nChw16c;
+    const format_tag_t nChw4c = mkldnn_nChw4c;
+    const format_tag_t nChw8c = mkldnn_nChw8c;
+    const format_tag_t nCw16c = mkldnn_nCw16c;
+    const format_tag_t nCw4c = mkldnn_nCw4c;
+    const format_tag_t nCw8c = mkldnn_nCw8c;
+    const format_tag_t IOw16o16i = mkldnn_IOw16o16i;
+    const format_tag_t OIw16i16o = mkldnn_OIw16i16o;
+    const format_tag_t OIw16o16i = mkldnn_OIw16o16i;
+    const format_tag_t Oiw16o = mkldnn_Oiw16o;
+    const format_tag_t OIw4i16o4i = mkldnn_OIw4i16o4i;
+    const format_tag_t OIw4i4o = mkldnn_OIw4i4o;
+    const format_tag_t Oiw4o = mkldnn_Oiw4o;
+    const format_tag_t OIw8i16o2i = mkldnn_OIw8i16o2i;
+    const format_tag_t OIw8i8o = mkldnn_OIw8i8o;
+    const format_tag_t OIw8o16i2o = mkldnn_OIw8o16i2o;
+    const format_tag_t OIw8o8i = mkldnn_OIw8o8i;
+    const format_tag_t Owi16o = mkldnn_Owi16o;
+    const format_tag_t Owi4o = mkldnn_Owi4o;
+    const format_tag_t Owi8o = mkldnn_Owi8o;
+    const format_tag_t IOhw16o16i = mkldnn_IOhw16o16i;
+    const format_tag_t Ohwi16o = mkldnn_Ohwi16o;
+    const format_tag_t Ohwi4o = mkldnn_Ohwi4o;
+    const format_tag_t Ohwi8o = mkldnn_Ohwi8o;
+    const format_tag_t OIhw16i16o = mkldnn_OIhw16i16o;
+    const format_tag_t OIhw16o16i = mkldnn_OIhw16o16i;
+    const format_tag_t Oihw16o = mkldnn_Oihw16o;
+    const format_tag_t OIhw4i16o4i = mkldnn_OIhw4i16o4i;
+    const format_tag_t OIhw4i4o = mkldnn_OIhw4i4o;
+    const format_tag_t Oihw4o = mkldnn_Oihw4o;
+    const format_tag_t OIhw8i16o2i = mkldnn_OIhw8i16o2i;
+    const format_tag_t OIhw8i8o = mkldnn_OIhw8i8o;
+    const format_tag_t OIhw8o16i2o = mkldnn_OIhw8o16i2o;
+    const format_tag_t OIhw8o8i = mkldnn_OIhw8o8i;
+    const format_tag_t Odhwi16o = mkldnn_Odhwi16o;
+    const format_tag_t Odhwi4o = mkldnn_Odhwi4o;
+    const format_tag_t Odhwi8o = mkldnn_Odhwi8o;
+    const format_tag_t OIdhw16i16o = mkldnn_OIdhw16i16o;
+    const format_tag_t OIdhw16o16i = mkldnn_OIdhw16o16i;
+    const format_tag_t Oidhw16o = mkldnn_Oidhw16o;
+    const format_tag_t OIdhw4i4o = mkldnn_OIdhw4i4o;
+    const format_tag_t Oidhw4o = mkldnn_Oidhw4o;
+    const format_tag_t OIdhw8i16o2i = mkldnn_OIdhw8i16o2i;
+    const format_tag_t OIdhw8i8o = mkldnn_OIdhw8i8o;
+    const format_tag_t OIdhw8o8i = mkldnn_OIdhw8o8i;
+    const format_tag_t gIOw16o16i = mkldnn_gIOw16o16i;
+    const format_tag_t Goiw16g = mkldnn_Goiw16g;
+    const format_tag_t gOIw16i16o = mkldnn_gOIw16i16o;
+    const format_tag_t gOIw16o16i = mkldnn_gOIw16o16i;
+    const format_tag_t gOiw16o = mkldnn_gOiw16o;
+    const format_tag_t gOIw4i16o4i = mkldnn_gOIw4i16o4i;
+    const format_tag_t gOIw4i4o = mkldnn_gOIw4i4o;
+    const format_tag_t gOiw4o = mkldnn_gOiw4o;
+    const format_tag_t gOIw8i16o2i = mkldnn_gOIw8i16o2i;
+    const format_tag_t gOIw8i8o = mkldnn_gOIw8i8o;
+    const format_tag_t gOIw8o16i2o = mkldnn_gOIw8o16i2o;
+    const format_tag_t gOIw8o8i = mkldnn_gOIw8o8i;
+    const format_tag_t gOwi16o = mkldnn_gOwi16o;
+    const format_tag_t gOwi4o = mkldnn_gOwi4o;
+    const format_tag_t gOwi8o = mkldnn_gOwi8o;
+    const format_tag_t gIOhw16o16i = mkldnn_gIOhw16o16i;
+    const format_tag_t gOhwi16o = mkldnn_gOhwi16o;
+    const format_tag_t gOhwi4o = mkldnn_gOhwi4o;
+    const format_tag_t gOhwi8o = mkldnn_gOhwi8o;
+    const format_tag_t Goihw16g = mkldnn_Goihw16g;
+    const format_tag_t gOIhw16i16o = mkldnn_gOIhw16i16o;
+    const format_tag_t gOIhw16o16i = mkldnn_gOIhw16o16i;
+    const format_tag_t gOihw16o = mkldnn_gOihw16o;
+    const format_tag_t gOIhw2i8o4i = mkldnn_gOIhw2i8o4i;
+    const format_tag_t gOIhw4i16o4i = mkldnn_gOIhw4i16o4i;
+    const format_tag_t gOIhw4i4o = mkldnn_gOIhw4i4o;
+    const format_tag_t gOIhw4o4i = mkldnn_gOIhw4o4i;
+    const format_tag_t gOihw4o = mkldnn_gOihw4o;
+    const format_tag_t Goihw8g = mkldnn_Goihw8g;
+    const format_tag_t gOIhw8i16o2i = mkldnn_gOIhw8i16o2i;
+    const format_tag_t gOIhw8i8o = mkldnn_gOIhw8i8o;
+    const format_tag_t gOIhw8o16i2o = mkldnn_gOIhw8o16i2o;
+    const format_tag_t gOIhw8o8i = mkldnn_gOIhw8o8i;
+    const format_tag_t gOdhwi16o = mkldnn_gOdhwi16o;
+    const format_tag_t gOdhwi4o = mkldnn_gOdhwi4o;
+    const format_tag_t gOdhwi8o = mkldnn_gOdhwi8o;
+    const format_tag_t gOIdhw16i16o = mkldnn_gOIdhw16i16o;
+    const format_tag_t gOIdhw16o16i = mkldnn_gOIdhw16o16i;
+    const format_tag_t gOidhw16o = mkldnn_gOidhw16o;
+    const format_tag_t gOIdhw4i4o = mkldnn_gOIdhw4i4o;
+    const format_tag_t gOidhw4o = mkldnn_gOidhw4o;
+    const format_tag_t gOIdhw8i16o2i = mkldnn_gOIdhw8i16o2i;
+    const format_tag_t gOIdhw8i8o = mkldnn_gOIdhw8i8o;
+    const format_tag_t gOIdhw8o8i = mkldnn_gOIdhw8o8i;
+}
+
+using memory_extra_flags_t = mkldnn_memory_extra_flags_t;
+namespace memory_extra_flags {
+    const memory_extra_flags_t none = mkldnn_memory_extra_flag_none;
+    const memory_extra_flags_t compensation_conv_s8s8 = mkldnn_memory_extra_flag_compensation_conv_s8s8;
+    const memory_extra_flags_t scale_adjust = mkldnn_memory_extra_flag_scale_adjust;
+}
+
+using padding_kind_t = mkldnn_padding_kind_t;
+namespace padding_kind {
+    const padding_kind_t padding_zero = mkldnn_padding_zero;
+}
+
+using engine_kind_t = mkldnn_engine_kind_t;
+namespace engine_kind {
+    const engine_kind_t any_engine = mkldnn_any_engine;
+    const engine_kind_t cpu = mkldnn_cpu;
+}
+
+using primitive_kind_t = mkldnn_primitive_kind_t;
+namespace primitive_kind {
+    const primitive_kind_t undefined = mkldnn_undefined_primitive;
+    const primitive_kind_t reorder = mkldnn_reorder;
+    const primitive_kind_t concat = mkldnn_concat;
+    const primitive_kind_t sum = mkldnn_sum;
+    const primitive_kind_t convolution = mkldnn_convolution;
+    const primitive_kind_t deconvolution = mkldnn_deconvolution;
+    const primitive_kind_t shuffle = mkldnn_shuffle;
+    const primitive_kind_t eltwise = mkldnn_eltwise;
+    const primitive_kind_t softmax = mkldnn_softmax;
+    const primitive_kind_t pooling = mkldnn_pooling;
+    const primitive_kind_t lrn = mkldnn_lrn;
+    const primitive_kind_t batch_normalization = mkldnn_batch_normalization;
+    const primitive_kind_t inner_product = mkldnn_inner_product;
+    const primitive_kind_t rnn = mkldnn_rnn;
+}
+
+using query_t = mkldnn_query_t;
+namespace query {
+    const query_t undef = mkldnn_query_undef;
+
+    const query_t engine = mkldnn_query_engine;
+    const query_t primitive_kind = mkldnn_query_primitive_kind;
+
+    const query_t num_of_inputs_s32 = mkldnn_query_num_of_inputs_s32;
+    const query_t num_of_outputs_s32 = mkldnn_query_num_of_outputs_s32;
+
+    const query_t time_estimate_f64 = mkldnn_query_time_estimate_f64;
+    const query_t memory_consumption_s64 = mkldnn_query_memory_consumption_s64;
+
+    const query_t scratchpad_engine = mkldnn_query_scratchpad_engine;
+
+    const query_t impl_info_str = mkldnn_query_impl_info_str;
+
+    const query_t some_d = mkldnn_query_some_d;
+    const query_t op_d = mkldnn_query_op_d;
+    const query_t convolution_d = mkldnn_query_convolution_d;
+    const query_t deconvolution_d = mkldnn_query_deconvolution_d;
+    const query_t shuffle_d = mkldnn_query_shuffle_d;
+    const query_t eltwise_d = mkldnn_query_eltwise_d;
+    const query_t softmax_d = mkldnn_query_softmax_d;
+    const query_t pooling_d = mkldnn_query_pooling_d;
+    const query_t lrn_d = mkldnn_query_lrn_d;
+    const query_t batch_normalization_d = mkldnn_query_batch_normalization_d;
+    const query_t inner_product_d = mkldnn_query_inner_product_d;
+    const query_t rnn_d = mkldnn_query_rnn_d;
+
+    const query_t some_md = mkldnn_query_some_md;
+    const query_t src_md = mkldnn_query_src_md;
+    const query_t diff_src_md = mkldnn_query_diff_src_md;
+    const query_t weights_md = mkldnn_query_weights_md;
+    const query_t diff_weights_md = mkldnn_query_diff_weights_md;
+    const query_t dst_md = mkldnn_query_dst_md;
+    const query_t diff_dst_md = mkldnn_query_diff_dst_md;
+
+    const query_t workspace_md = mkldnn_query_workspace_md;
+    const query_t scratchpad_md = mkldnn_query_scratchpad_md;
+}
+
+using blocking_desc_t = mkldnn_blocking_desc_t;
+using rnn_packed_desc_t = mkldnn_rnn_packed_desc_t;
+using wino_desc_t = mkldnn_wino_desc_t;
+using memory_extra_desc_t = mkldnn_memory_extra_desc_t;
+using memory_desc_t = mkldnn_memory_desc_t;
+using convolution_desc_t = mkldnn_convolution_desc_t;
+using deconvolution_desc_t = mkldnn_deconvolution_desc_t;
+using shuffle_desc_t = mkldnn_shuffle_desc_t;
+using pooling_desc_t = mkldnn_pooling_desc_t;
+using eltwise_desc_t = mkldnn_eltwise_desc_t;
+using softmax_desc_t = mkldnn_softmax_desc_t;
+using lrn_desc_t = mkldnn_lrn_desc_t;
+using batch_normalization_desc_t = mkldnn_batch_normalization_desc_t;
+using inner_product_desc_t = mkldnn_inner_product_desc_t;
+
+using rnn_direction_t = mkldnn_rnn_direction_t;
+using rnn_cell_desc_t = mkldnn_rnn_cell_desc_t;
+using rnn_desc_t = mkldnn_rnn_desc_t;
+
+/* C op_desc_t types, which are ultimately just (void *) */
+using c_op_desc_t = mkldnn_op_desc_t;
+using const_c_op_desc_t = const_mkldnn_op_desc_t;
+
+struct op_desc_t {
+    union {
+        primitive_kind_t kind;
+        convolution_desc_t convolution;
+        deconvolution_desc_t deconvolution;
+        shuffle_desc_t shuffle;
+        pooling_desc_t pooling;
+        eltwise_desc_t eltwise;
+        softmax_desc_t softmax;
+        lrn_desc_t lrn;
+        batch_normalization_desc_t batch_normalization;
+        inner_product_desc_t inner_product;
+        rnn_desc_t rnn;
+    };
+
+    op_desc_t(const primitive_kind_t &_): kind(_) {}
+
+#   define DECL_CTOR_AND_CONVERTERS(c_type, name) \
+    op_desc_t(const c_type &_): name(_) {} \
+    static op_desc_t *convert_from_c(c_type *_) \
+    { return reinterpret_cast<op_desc_t*>(_); } \
+    static const op_desc_t *convert_from_c(const c_type *_) \
+    { return reinterpret_cast<const op_desc_t*>(_); }
+
+    DECL_CTOR_AND_CONVERTERS(convolution_desc_t, convolution);
+    DECL_CTOR_AND_CONVERTERS(shuffle_desc_t, shuffle);
+    DECL_CTOR_AND_CONVERTERS(pooling_desc_t, pooling);
+    DECL_CTOR_AND_CONVERTERS(eltwise_desc_t, eltwise);
+    DECL_CTOR_AND_CONVERTERS(softmax_desc_t, softmax);
+    DECL_CTOR_AND_CONVERTERS(lrn_desc_t, lrn);
+    DECL_CTOR_AND_CONVERTERS(batch_normalization_desc_t, batch_normalization);
+    DECL_CTOR_AND_CONVERTERS(inner_product_desc_t, inner_product);
+    DECL_CTOR_AND_CONVERTERS(rnn_desc_t, rnn);
+
+#   undef DECL_CTOR_AND_CONVERTERS
+};
+
+using engine_t = mkldnn_engine;
+using primitive_desc_iterator_t = mkldnn_primitive_desc_iterator;
+using primitive_desc_t = mkldnn_primitive_desc;
+using primitive_attr_t = mkldnn_primitive_attr;
+using post_ops_t = mkldnn_post_ops;
+using memory_t = mkldnn_memory;
+using primitive_t = mkldnn_primitive;
+
+using primitive_arg_index_t = int;
+
+using stream_flags_t = mkldnn_stream_flags_t;
+namespace stream_flags {
+    const stream_flags_t default_flags = mkldnn_stream_default_flags;
+}
+using stream_t = mkldnn_stream;
+
+/* forward declaration of the internal primitive_desc types */
+struct batch_normalization_bwd_pd_t;
+struct batch_normalization_fwd_pd_t;
+struct batch_normalization_pd_t;
+struct concat_pd_t;
+struct convolution_bwd_data_pd_t;
+struct convolution_bwd_weights_pd_t;
+struct convolution_fwd_pd_t;
+struct convolution_pd_t;
+struct deconvolution_bwd_data_pd_t;
+struct deconvolution_bwd_weights_pd_t;
+struct deconvolution_fwd_pd_t;
+struct deconvolution_pd_t;
+struct eltwise_bwd_pd_t;
+struct eltwise_fwd_pd_t;
+struct eltwise_pd_t;
+struct inner_product_bwd_data_pd_t;
+struct inner_product_bwd_weights_pd_t;
+struct inner_product_fwd_pd_t;
+struct inner_product_pd_t;
+struct lrn_bwd_pd_t;
+struct lrn_fwd_pd_t;
+struct lrn_pd_t;
+struct pooling_bwd_pd_t;
+struct pooling_fwd_pd_t;
+struct pooling_pd_t;
+struct reorder_pd_t;
+struct rnn_bwd_pd_t;
+struct rnn_fwd_pd_t;
+struct rnn_pd_t;
+struct shuffle_pd_t;
+struct softmax_bwd_pd_t;
+struct softmax_fwd_pd_t;
+struct softmax_pd_t;
+struct sum_pd_t;
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
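The op_desc_t union above gives the implementation a single view over every concrete descriptor, and the generated convert_from_c helpers reinterpret the C-side pointers into that view. A minimal sketch of the intended round trip, assuming only the declarations from this header (illustrative; not part of the vendored sources):

    // Illustrative sketch. Each concrete *_desc_t in mkldnn_types.h begins
    // with its primitive_kind field, which is what makes reading
    // op_desc_t::kind meaningful no matter which union member the pointer
    // actually refers to.
    #include "c_types_map.hpp"

    namespace mkldnn {
    namespace impl {

    inline bool describes_convolution(const convolution_desc_t *c_desc) {
        const op_desc_t *od = op_desc_t::convert_from_c(c_desc);
        return od->kind == primitive_kind::convolution;
    }

    }
    }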

+ 86 - 0
thirdparty/oidn/mkl-dnn/src/common/concat.cpp

@@ -0,0 +1,86 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "concat_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_concat_primitive_desc_create(primitive_desc_t **concat_pd,
+        const memory_desc_t *dst_md, int n, int concat_dim,
+        const memory_desc_t *src_mds,
+        const primitive_attr_t *attr,
+        engine_t *engine) {
+    bool args_ok = !any_null(concat_pd, src_mds) && n > 0;
+    if (!args_ok) return invalid_arguments;
+
+    const primitive_attr_t dummy_attr;
+    if (attr == NULL)
+        attr = &dummy_attr;
+
+    const int ndims = src_mds[0].ndims;
+    const dims_t &dims = src_mds[0].dims;
+    const data_type_t dt = src_mds[0].data_type;
+
+    int concat_dim_sz = dims[concat_dim];
+    for (int i = 1; i < n; ++i) {
+        if (src_mds[i].ndims != ndims) return invalid_arguments;
+        for (int d = 0; d < ndims; ++d) {
+            if (d == concat_dim) continue;
+            if (src_mds[i].dims[d] != dims[d])
+                return invalid_arguments;
+        }
+        if (src_mds[i].data_type != dt) return invalid_arguments;
+        concat_dim_sz += src_mds[i].dims[concat_dim];
+    }
+
+    memory_desc_t dummy_dst_md;
+    if (dst_md) {
+        if (dst_md->ndims != ndims) return invalid_arguments;
+        for (int d = 0; d < ndims; ++d) {
+            if (dst_md->dims[d] !=
+                    (d == concat_dim ? concat_dim_sz : dims[d]))
+                return invalid_arguments;
+        }
+    } else {
+        dummy_dst_md = src_mds[0];
+        dummy_dst_md.dims[concat_dim] = concat_dim_sz;
+        dummy_dst_md.format_kind = format_kind::any;
+        dst_md = &dummy_dst_md;
+    }
+
+    auto c_pd = reinterpret_cast<concat_pd_t **>(concat_pd);
+
+    for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
+        if ((*c)(c_pd, engine, attr, dst_md, n, concat_dim, src_mds)
+                == success) {
+            (*c_pd)->init_info();
+            (*c_pd)->init_scratchpad_md();
+            return success;
+        }
+    }
+    return unimplemented;
+}
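For reference, the validation above enforces three rules: every source matches the first in rank, data type and all dimensions except concat_dim, and the (optional) destination's concat_dim size must equal the sum of the sources'. A standalone restatement with plain integers (an illustrative sketch; toy_md is a made-up stand-in for memory_desc_t, and the data-type check is omitted):

    // Restates the shape check in mkldnn_concat_primitive_desc_create.
    #include <vector>

    struct toy_md { int ndims; std::vector<int> dims; };

    // Returns the implied dst dims, or an empty vector when the sources are
    // inconsistent (the invalid_arguments paths above).
    std::vector<int> concat_dst_dims(const std::vector<toy_md> &srcs,
                                     int concat_dim) {
        if (srcs.empty()) return {};
        std::vector<int> dst = srcs[0].dims;
        for (size_t i = 1; i < srcs.size(); ++i) {
            if (srcs[i].ndims != srcs[0].ndims) return {};
            for (int d = 0; d < srcs[0].ndims; ++d)
                if (d != concat_dim && srcs[i].dims[d] != srcs[0].dims[d])
                    return {};
            dst[concat_dim] += srcs[i].dims[concat_dim]; // concat_dim_sz
        }
        return dst;
    }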

+ 211 - 0
thirdparty/oidn/mkl-dnn/src/common/concat_pd.hpp

@@ -0,0 +1,211 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONCAT_PD_HPP
+#define CONCAT_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct concat_pd_t: public primitive_desc_t {
+    concat_pd_t(engine_t *engine, const primitive_attr_t *attr,
+            const memory_desc_t *dst_md, int n, int concat_dim,
+            const memory_desc_t *src_mds)
+        : primitive_desc_t(engine, attr, primitive_kind::concat)
+        , n_(n), concat_dim_(concat_dim), dst_md_(*dst_md)
+    {
+        src_mds_.reserve(n_);
+        for (int i = 0; i < n_; ++i) src_mds_.push_back(src_mds[i]);
+    }
+
+    concat_pd_t(const concat_pd_t &rhs) = default;
+
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg >= MKLDNN_ARG_MULTIPLE_SRC
+                && arg < MKLDNN_ARG_MULTIPLE_SRC + n_inputs())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index < n_inputs() ? &src_mds_[index] : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return n_; }
+    virtual int n_outputs() const override { return 1; }
+
+    int concat_dim() const { return concat_dim_; }
+
+    const memory_desc_t *src_image_md(int index = 0) const
+    { return index < n_inputs() ? &src_image_mds_[index] : nullptr; }
+
+protected:
+    int n_, concat_dim_;
+    memory_desc_t dst_md_;
+    nstl::vector<memory_desc_t> src_mds_;
+
+    /* contains images of srcs in the dst memory (if possible)
+     * Lives here to simplify some implementations. An implementation might
+     * use this auxiliary array iff init() returned success */
+    nstl::vector<memory_desc_t> src_image_mds_;
+
+protected:
+    /* inits src_image_mds_ and dst_md_ in simple cases. The call may fail */
+    status_t init() {
+        bool ok = true
+            && set_default_params() == status::success
+            && attr()->has_default_values();
+        if (!ok) return status::unimplemented;
+
+        for (int i = 0; i < n_; ++i) {
+            const memory_desc_wrapper i_d(&src_mds_[i]);
+            if (!i_d.is_blocking_desc() || i_d.is_additional_buffer())
+                return status::unimplemented;
+        }
+
+        const int ndims = dst_md_.ndims;
+        int current_concat_dim_offset = 0;
+        for (int i = 0; i < n_; ++i) {
+            const int dim = src_mds_[i].dims[concat_dim_];
+            dims_t dims, offsets = {};
+            utils::array_copy(dims, dst_md_.dims, ndims);
+            dims[concat_dim_] = dim;
+            offsets[concat_dim_] = current_concat_dim_offset;
+
+            memory_desc_t src_img_d;
+            status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+                    &dst_md_, dims, offsets);
+            if (status != status::success) return status;
+            src_image_mds_.push_back(src_img_d);
+            current_concat_dim_offset += dim;
+        }
+
+        return status::success;
+    }
+
+    status_t set_default_params() {
+        if (dst_md_.format_kind != format_kind::any)
+            return status::success;
+
+        const int ndims = dst_md_.ndims;
+
+        /* A deliberately simple heuristic (though not the same as before):
+         *  - Pick the first non-plain format;
+         *  - If all formats are plain, or a blocked format cannot be created
+         *    for the output, pick the format of the first plain input;
+         *  - If that fails as well, fall back to the plain layout (abcd...).
+         */
+        status_t status = status::unimplemented;
+        for (int i = 0; i < n_; ++i) {
+            const memory_desc_wrapper src_d(src_mds_[i]);
+            if (src_d.is_blocking_desc() && !src_d.is_plain()) {
+                status = memory_desc_init_by_blocking_desc(dst_md_,
+                        src_d.blocking_desc());
+                if (status == status::success) break;
+            }
+        }
+
+        if (status == status::success) {
+            /* check if we can create a sub-memory for the dst */
+            bool desired_format_ok = true;
+            int current_concat_dim_offset = 0;
+            for (int i = 0; i < n_; ++i) {
+                const int dim = src_mds_[i].dims[concat_dim_];
+                dims_t dims, offsets = {};
+                utils::array_copy(dims, dst_md_.dims, ndims);
+                dims[concat_dim_] = dim;
+                offsets[concat_dim_] = current_concat_dim_offset;
+
+                memory_desc_t src_img_d;
+                status_t status = mkldnn_memory_desc_init_submemory(&src_img_d,
+                        &dst_md_, dims, offsets);
+                if (status != status::success) {
+                    desired_format_ok = false;
+                    break;
+                }
+                current_concat_dim_offset += dim;
+            }
+
+            if (!desired_format_ok)
+                status = status::unimplemented;
+        }
+
+        /* if no success so far, try using the format of the first plain input */
+        if (status != status::success) {
+            for (int i = 0; i < n_; ++i) {
+                const memory_desc_wrapper src_d(src_mds_[i]);
+                if (src_d.is_blocking_desc() && src_d.is_plain()) {
+                    status = memory_desc_init_by_blocking_desc(dst_md_,
+                            memory_desc_wrapper(src_mds_[0]).blocking_desc());
+                    if (status == status::success) return status;
+                }
+            }
+        }
+
+        /* the last line of defense: use plain abcd... format */
+        if (status != status::success)
+            status = memory_desc_init_by_strides(dst_md_, nullptr);
+
+        return status;
+    }
+};
+
+#define DECLARE_CONCAT_PD_t(impl_name, ...) \
+    static status_t create(concat_pd_t **concat_pd, \
+            engine_t *engine, const primitive_attr_t *attr, \
+            const memory_desc_t *dst_md, int n, int concat_dim, \
+            const memory_desc_t *src_mds) { \
+        using namespace status; \
+        auto _pd = new pd_t(engine, attr, dst_md, n, concat_dim, src_mds); \
+        if (_pd == nullptr) return out_of_memory; \
+        if (_pd->init() != success) { delete _pd; return unimplemented; } \
+        return safe_ptr_assign<concat_pd_t>(*concat_pd, _pd); \
+    } \
+    virtual status_t create_primitive(primitive_t **p) const override { \
+        double ms = get_msec(); \
+        auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+        ms = get_msec() - ms; \
+        if (mkldnn_verbose()->level >= 2) { \
+            printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+            fflush(0); \
+        } \
+        return ret; \
+    } \
+    virtual pd_t *clone() const override { return new pd_t(*this); } \
+    virtual const char *name() const override { return impl_name; } \
+
+#define DECLARE_CONCAT_PD_T(impl_name, ...) \
+    DECLARE_CONCAT_PD_t(impl_name, __VA_ARGS__)
+
+}
+}
+
+#endif

+ 200 - 0
thirdparty/oidn/mkl-dnn/src/common/convolution.cpp

@@ -0,0 +1,200 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace mkldnn {
+namespace impl {
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t dilates,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    bool args_ok = true
+        && !any_null(conv_desc, src_desc, weights_desc, dst_desc, strides,
+                padding_l)
+        && one_of(alg_kind, convolution_auto, convolution_direct, convolution_winograd)
+        && one_of(padding_kind, padding_kind::padding_zero);
+    if (!args_ok) return invalid_arguments;
+
+    if (padding_r == nullptr) padding_r = padding_l;
+
+    auto cd = convolution_desc_t();
+    cd.primitive_kind = primitive_kind::convolution;
+    cd.prop_kind = prop_kind;
+    cd.alg_kind = alg_kind;
+
+    cd.diff_src_desc = cd.src_desc = zero_md();
+    cd.diff_dst_desc = cd.dst_desc = zero_md();
+    cd.diff_weights_desc = cd.weights_desc = zero_md();
+    cd.diff_bias_desc = cd.bias_desc = zero_md();
+
+    const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+    const bool with_bias =
+        bias_desc && bias_desc->format_kind != format_kind::undef;
+    const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+    (prop_kind == backward_data ? cd.diff_src_desc : cd.src_desc) = *src_desc;
+    (is_fwd ? cd.dst_desc : cd.diff_dst_desc)  = *dst_desc;
+    (prop_kind == backward_weights ? cd.diff_weights_desc : cd.weights_desc) =
+        *weights_desc;
+    if (with_bias)
+        (prop_kind == backward_weights ? cd.diff_bias_desc : cd.bias_desc) =
+            *bias_desc;
+
+    int sp_dims = src_desc->ndims - 2;
+    utils::array_copy(cd.strides, strides, sp_dims);
+    utils::array_copy(cd.padding[0], padding_l, sp_dims);
+    utils::array_copy(cd.padding[1], padding_r, sp_dims);
+    if (dilates)
+        utils::array_copy(cd.dilates, dilates, sp_dims);
+    else
+        utils::array_set(cd.dilates, 0, sp_dims);
+
+    cd.padding_kind = padding_kind;
+    cd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+            weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+    const int g = with_groups ? weights_desc->dims[0] : 1;
+    const int bias_dim = prop_kind == backward_data
+        ? src_desc->dims[1]
+        : dst_desc->dims[1];
+
+    bool consistency = true
+        && memory_desc_wrapper(weights_desc).nelems()
+        && src_desc->ndims == dst_desc->ndims
+        && utils::one_of(src_desc->ndims, 3, 4, 5)
+        && utils::one_of(weights_desc->ndims, src_desc->ndims,
+                src_desc->ndims + 1)
+        && (with_bias ? bias_desc->ndims == 1 : true)
+        && (with_bias ? bias_desc->dims[0] == bias_dim : true)
+        && src_desc->dims[0] == dst_desc->dims[0]
+        && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+        && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+    for (int i = 2; i < src_desc->ndims; ++i)
+    {
+        int src = src_desc->dims[i];
+        int ker = weights_desc->dims[with_groups + i];
+        int dil = cd.dilates[i - 2];
+        int pad_l = padding_l[i - 2];
+        int pad_r = padding_r[i - 2];
+        int str = strides[i - 2];
+        int dst = dst_desc->dims[i];
+        int ker_range = 1 + (ker - 1) * (dil + 1);
+
+        if (str < 1) return invalid_arguments;
+        consistency = consistency
+            && dil >= 0
+            && pad_l >= 0
+            && pad_r + str > 0
+            && (src - ker_range + pad_l + pad_r) / str + 1 == dst;
+    }
+    if (!consistency) return invalid_arguments;
+
+    *conv_desc = cd;
+    return success;
+}
+}
+}
+
+status_t mkldnn_convolution_forward_desc_init(convolution_desc_t *conv_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+            weights_desc, bias_desc, dst_desc, strides, nullptr,
+            padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_forward_desc_init(
+        convolution_desc_t *conv_desc, prop_kind_t prop_kind,
+        alg_kind_t alg_kind, const memory_desc_t *src_desc,
+        const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_desc, const dims_t strides,
+        const dims_t dilates, const dims_t padding_l,
+        const dims_t padding_r, padding_kind_t padding_kind) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return mkldnn::impl::conv_desc_init(conv_desc, prop_kind, alg_kind, src_desc,
+            weights_desc, bias_desc, dst_desc, strides, dilates,
+            padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_data_desc_init(
+        convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+            weights_desc, nullptr, diff_dst_desc, strides, nullptr,
+            padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_data_desc_init(
+        convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return mkldnn::impl::conv_desc_init(conv_desc, backward_data, alg_kind, diff_src_desc,
+            weights_desc, nullptr, diff_dst_desc, strides, dilates,
+            padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_convolution_backward_weights_desc_init(
+        convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+        const memory_desc_t *diff_bias_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+            diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+            nullptr, padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_convolution_backward_weights_desc_init(
+        convolution_desc_t *conv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+        const memory_desc_t *diff_bias_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return mkldnn::impl::conv_desc_init(conv_desc, backward_weights, alg_kind, src_desc,
+            diff_weights_desc, diff_bias_desc, diff_dst_desc, strides,
+            dilates, padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
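The per-dimension consistency check in conv_desc_init encodes the usual convolution output-size relation: with ker_range = 1 + (ker - 1) * (dil + 1), it requires (src - ker_range + pad_l + pad_r) / str + 1 == dst. A minimal illustrative helper (not part of the vendored sources) that makes the relation explicit:

    // Output spatial size implied by the consistency check in conv_desc_init.
    inline int conv_out_dim(int src, int ker, int dil,
                            int pad_l, int pad_r, int str) {
        const int ker_range = 1 + (ker - 1) * (dil + 1); // dilated kernel extent
        return (src - ker_range + pad_l + pad_r) / str + 1;
    }

    // Example: a 3x3 kernel, stride 1, padding 1, no dilation preserves size:
    // conv_out_dim(64, 3, 0, 1, 1, 1) == 64.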

+ 56 - 0
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.cpp

@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "convolution_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc) {
+    return desc->prop_kind == backward_data
+        ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc) {
+    return desc->prop_kind == backward_weights
+        ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc) {
+    return desc->prop_kind == backward_weights
+        ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc) {
+    return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+        ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_src_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_wei_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_bia_d(const_cast<convolution_desc_t *>(desc)); }
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc)
+{ return conv_prop_invariant_dst_d(const_cast<convolution_desc_t *>(desc)); }
+
+}
+}

+ 348 - 0
thirdparty/oidn/mkl-dnn/src/common/convolution_pd.hpp

@@ -0,0 +1,348 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CONVOLUTION_PD_HPP
+#define CONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t conv_desc_init(convolution_desc_t *conv_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t dilates,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind);
+
+memory_desc_t *conv_prop_invariant_src_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_wei_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_bia_d(convolution_desc_t *desc);
+memory_desc_t *conv_prop_invariant_dst_d(convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_src_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_wei_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_bia_d(const convolution_desc_t *desc);
+const memory_desc_t *conv_prop_invariant_dst_d(const convolution_desc_t *desc);
+
+struct convolution_fwd_pd_t;
+
+struct convolution_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::convolution;
+
+    convolution_pd_t(engine_t *engine,
+            const convolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const convolution_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+    {}
+
+    const convolution_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case pkind_traits<base_pkind>::query_d:
+            *(const convolution_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common conv aux functions */
+
+    dim_t MB() const { return _src_md()->dims[0]; }
+
+    dim_t IC() const { return _src_md()->dims[1]; }
+    dim_t OC() const { return _dst_md()->dims[1]; }
+    dim_t G() const { return with_groups() ? _wei_md()->dims[0] : 1; }
+
+    dim_t ID() const { return ndims() >= 5 ? _src_md()->dims[ndims() - 3] : 1; }
+    dim_t IH() const { return ndims() >= 4 ? _src_md()->dims[ndims() - 2] : 1; }
+    dim_t IW() const { return _src_md()->dims[ndims() - 1]; }
+
+    dim_t OD() const { return ndims() >= 5 ? _dst_md()->dims[ndims() - 3] : 1; }
+    dim_t OH() const { return ndims() >= 4 ? _dst_md()->dims[ndims() - 2] : 1; }
+    dim_t OW() const { return _dst_md()->dims[ndims() - 1]; }
+
+    dim_t KD() const { return ndims() >= 5 ? _wei_md()->dims[ndims() + with_groups() - 3] : 1; }
+    dim_t KH() const { return ndims() >= 4 ? _wei_md()->dims[ndims() + with_groups() - 2] : 1; }
+    dim_t KW() const { return _wei_md()->dims[ndims() + with_groups() - 1]; }
+
+    dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+    dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+    dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+    dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+    dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+    dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+    dim_t padFront() const { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+    dim_t padBack() const { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+    dim_t padT() const { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+    dim_t padB() const { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+    dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+    dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+    int ndims() const { return _src_md()->ndims; }
+
+    bool with_bias() const { return !memory_desc_wrapper(*_bia_md()).is_zero(); }
+    bool with_groups() const { return _wei_md()->ndims == ndims() + 1; }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+    bool has_zero_dim_memory() const {
+        const auto s_d = memory_desc_wrapper(*_src_md());
+        const auto d_d = memory_desc_wrapper(*_dst_md());
+        return s_d.has_zero_dim() || d_d.has_zero_dim();
+    }
+
+protected:
+    convolution_desc_t desc_;
+    const convolution_fwd_pd_t *hint_fwd_pd_;
+
+    bool set_default_formats_common_template(
+            memory_desc_t &src_md, format_tag_t src_tag,
+            memory_desc_t &wei_md, format_tag_t wei_tag,
+            memory_desc_t &dst_md, format_tag_t dst_tag,
+            memory_desc_t &bia_md) {
+        using namespace format_tag;
+
+#       define IS_OK(f) \
+        do { if ((f) != status::success) return false; } while(0)
+        if (src_md.format_kind == format_kind::any
+                && !utils::one_of(src_tag, any, undef))
+            IS_OK(memory_desc_init_by_tag(src_md, src_tag));
+        if (dst_md.format_kind == format_kind::any
+                && !utils::one_of(dst_tag, any, undef))
+            IS_OK(memory_desc_init_by_tag(dst_md, dst_tag));
+        if (wei_md.format_kind == format_kind::any
+                && !utils::one_of(wei_tag, any, undef))
+            IS_OK(memory_desc_init_by_tag(wei_md, wei_tag));
+        if (with_bias() && bia_md.format_kind == format_kind::any)
+            IS_OK(memory_desc_init_by_tag(bia_md, x));
+#       undef IS_OK
+
+        return true;
+    }
+
+    bool set_default_alg_kind(alg_kind_t alg_kind) {
+        assert(utils::one_of(alg_kind, alg_kind::convolution_direct,
+                    alg_kind::convolution_winograd));
+        if (desc_.alg_kind == alg_kind::convolution_auto)
+            desc_.alg_kind = alg_kind;
+        return desc_.alg_kind == alg_kind;
+    }
+
+    bool expect_data_types(data_type_t src_dt, data_type_t wei_dt,
+            data_type_t bia_dt, data_type_t dst_dt, data_type_t acc_dt) const {
+        bool ok = true
+            && (src_dt == data_type::undef || _src_md()->data_type == src_dt)
+            && (wei_dt == data_type::undef || _wei_md()->data_type == wei_dt)
+            && (dst_dt == data_type::undef || _dst_md()->data_type == dst_dt)
+            && (acc_dt == data_type::undef || desc_.accum_data_type == acc_dt);
+        if (with_bias() && bia_dt != data_type::undef)
+            ok = ok && _bia_md()->data_type == bia_dt;
+        return ok;
+    }
+
+private:
+    const memory_desc_t *_src_md() const { return conv_prop_invariant_src_d(&desc_); }
+    const memory_desc_t *_wei_md() const { return conv_prop_invariant_wei_d(&desc_); }
+    const memory_desc_t *_bia_md() const { return conv_prop_invariant_bia_d(&desc_); }
+    const memory_desc_t *_dst_md() const { return conv_prop_invariant_dst_d(&desc_); }
+};
+
+struct convolution_fwd_pd_t: public convolution_pd_t {
+    typedef convolution_fwd_pd_t base_class;
+    typedef convolution_fwd_pd_t hint_class;
+
+    convolution_fwd_pd_t(engine_t *engine,
+            const convolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const convolution_fwd_pd_t *hint_fwd_pd)
+        : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , weights_md_(desc_.weights_desc)
+        , bias_md_(desc_.bias_desc)
+        , dst_md_(desc_.dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_BIAS && with_bias())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override {
+        if (index == 0) return &weights_md_;
+        if (index == 1 && with_bias()) return &bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2 + with_bias(); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t bias_md_;
+    memory_desc_t dst_md_;
+
+    bool set_default_formats_common(format_tag_t src_tag,
+            format_tag_t wei_tag, format_tag_t dst_tag) {
+        return set_default_formats_common_template(src_md_, src_tag,
+                weights_md_, wei_tag, dst_md_, dst_tag, bias_md_);
+    }
+};
+
+struct convolution_bwd_data_pd_t: public convolution_pd_t {
+    typedef convolution_bwd_data_pd_t base_class;
+    typedef convolution_fwd_pd_t hint_class;
+
+    convolution_bwd_data_pd_t(engine_t *engine,
+            const convolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const convolution_fwd_pd_t *hint_fwd_pd)
+        : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_src_md_(desc_.diff_src_desc)
+        , weights_md_(desc_.weights_desc)
+        , bias_md_(desc_.bias_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override {
+        if (index == 0) return &weights_md_;
+        if (index == 1 && with_bias()) return &bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2 + with_bias(); }
+    virtual int n_outputs() const override { return 1; }
+
+    virtual bool support_bias() const { return false; }
+
+protected:
+    memory_desc_t diff_src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t bias_md_;
+    memory_desc_t diff_dst_md_;
+
+    bool set_default_formats_common(format_tag_t diff_src_tag,
+            format_tag_t wei_tag, format_tag_t diff_dst_tag) {
+        return set_default_formats_common_template(diff_src_md_, diff_src_tag,
+                weights_md_, wei_tag, diff_dst_md_, diff_dst_tag, bias_md_);
+    }
+};
+
+struct convolution_bwd_weights_pd_t: public convolution_pd_t {
+    typedef convolution_bwd_weights_pd_t base_class;
+    typedef convolution_fwd_pd_t hint_class;
+
+    convolution_bwd_weights_pd_t(engine_t *engine,
+            const convolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const convolution_fwd_pd_t *hint_fwd_pd)
+        : convolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , diff_weights_md_(desc_.diff_weights_desc)
+        , diff_bias_md_(desc_.diff_bias_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+        if (index == 0) return &diff_weights_md_;
+        if (index == 1 && with_bias()) return &diff_bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t diff_weights_md_;
+    memory_desc_t diff_bias_md_;
+    memory_desc_t diff_dst_md_;
+
+    bool set_default_formats_common(format_tag_t src_tag,
+            format_tag_t diff_wei_tag, format_tag_t diff_dst_tag) {
+        return set_default_formats_common_template(src_md_, src_tag,
+                diff_weights_md_, diff_wei_tag, diff_dst_md_, diff_dst_tag,
+                diff_bias_md_);
+    }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
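The accessor helpers in convolution_pd_t index dims from the back of the array, so the same code serves 1D, 2D and 3D convolutions, and with_groups() simply adds one leading G dimension to the weights. A small self-check of that indexing convention (illustrative sketch; the raw array stands in for memory_desc_t::dims):

    // Verifies the weights indexing used by KH()/KW() above for a grouped
    // 2D convolution (ndims() == 4, with_groups() == true).
    #include <cassert>

    int main() {
        const int ndims = 4, with_groups = 1;
        const int wei_dims[] = { 2, 8, 4, 3, 3 }; // { G, OC/G, IC/G, KH, KW }

        const int KH = wei_dims[ndims + with_groups - 2];
        const int KW = wei_dims[ndims + with_groups - 1];
        assert(KH == 3 && KW == 3);
        return 0;
    }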

+ 188 - 0
thirdparty/oidn/mkl-dnn/src/common/deconvolution.cpp

@@ -0,0 +1,188 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t deconv_desc_init(deconvolution_desc_t *deconv_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t dilates, const dims_t padding_l,
+        const dims_t padding_r, padding_kind_t padding_kind) {
+    bool args_ok = true
+            && !any_null(deconv_desc, src_desc, weights_desc, dst_desc, strides,
+                           padding_l)
+            && one_of(alg_kind, deconvolution_direct, deconvolution_winograd)
+            && one_of(padding_kind, padding_kind::padding_zero);
+    if (!args_ok)
+        return invalid_arguments;
+
+    if (padding_r == nullptr)
+        padding_r = padding_l;
+
+    auto dd = deconvolution_desc_t();
+    dd.primitive_kind = primitive_kind::deconvolution;
+    dd.prop_kind = prop_kind;
+    dd.alg_kind = alg_kind;
+
+    dd.diff_src_desc = dd.src_desc = zero_md();
+    dd.diff_dst_desc = dd.dst_desc = zero_md();
+    dd.diff_weights_desc = dd.weights_desc = zero_md();
+    dd.diff_bias_desc = dd.bias_desc = zero_md();
+
+    const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+    const bool with_bias
+            = bias_desc && bias_desc->format_kind != format_kind::undef;
+    const bool with_groups = weights_desc->ndims == src_desc->ndims + 1;
+
+    (prop_kind == backward_data ? dd.diff_src_desc : dd.src_desc) = *src_desc;
+    (is_fwd ? dd.dst_desc : dd.diff_dst_desc) = *dst_desc;
+    (prop_kind == backward_weights ? dd.diff_weights_desc : dd.weights_desc)
+            = *weights_desc;
+    if (with_bias)
+        (prop_kind == backward_weights ? dd.diff_bias_desc : dd.bias_desc)
+                = *bias_desc;
+
+    int sp_dims = src_desc->ndims - 2;
+    utils::array_copy(dd.strides, strides, sp_dims);
+    utils::array_copy(dd.padding[0], padding_l, sp_dims);
+    utils::array_copy(dd.padding[1], padding_r, sp_dims);
+    if (dilates)
+        utils::array_copy(dd.dilates, dilates, sp_dims);
+    else
+        utils::array_set(dd.dilates, 0, sp_dims);
+
+    dd.padding_kind = padding_kind;
+    dd.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+            weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+    const int g = with_groups ? weights_desc->dims[0] : 1;
+    bool consistency = true
+            && src_desc->ndims == dst_desc->ndims
+            && utils::one_of(src_desc->ndims, 3, 4, 5)
+            && utils::one_of(weights_desc->ndims, src_desc->ndims,
+                    src_desc->ndims + 1)
+            && (with_bias ? bias_desc->ndims == 1 : true)
+            && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+            && src_desc->dims[0] == dst_desc->dims[0]
+            && src_desc->dims[1] == g * weights_desc->dims[with_groups + 1]
+            && dst_desc->dims[1] == g * weights_desc->dims[with_groups + 0];
+    for (int i = 2; i < src_desc->ndims; ++i) {
+        int src = src_desc->dims[i];
+        int ker = weights_desc->dims[with_groups + i];
+        int dil = dd.dilates[i - 2];
+        int pad = padding_l[i - 2] + padding_r[i - 2];
+        int str = strides[i - 2];
+        int dst = dst_desc->dims[i];
+        int ker_range = 1 + (ker - 1) * (dil + 1);
+
+        consistency
+                = consistency && (dst - ker_range + pad) / str + 1 == src;
+    }
+    if (!consistency)
+        return invalid_arguments;
+
+    *deconv_desc = dd;
+    return success;
+}
+}
+
+status_t mkldnn_deconvolution_forward_desc_init(
+        deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+        alg_kind_t alg_kind, const memory_desc_t *src_desc,
+        const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_desc, const dims_t strides,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+            weights_desc, bias_desc, dst_desc, strides, nullptr, padding_l,
+            padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_forward_desc_init(
+        deconvolution_desc_t *deconv_desc, prop_kind_t prop_kind,
+        alg_kind_t alg_kind, const memory_desc_t *src_desc,
+        const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_desc, const dims_t strides,
+        const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return deconv_desc_init(deconv_desc, prop_kind, alg_kind, src_desc,
+            weights_desc, bias_desc, dst_desc, strides, dilates, padding_l,
+            padding_r, padding_kind);
+}
+
+status_t mkldnn_deconvolution_backward_data_desc_init(
+        deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+            weights_desc, nullptr, diff_dst_desc, strides, nullptr, padding_l,
+            padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_backward_data_desc_init(
+        deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *diff_src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t dilates, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return deconv_desc_init(deconv_desc, backward_data, alg_kind, diff_src_desc,
+            weights_desc, nullptr, diff_dst_desc, strides, dilates, padding_l,
+            padding_r, padding_kind);
+}
+
+status_t mkldnn_deconvolution_backward_weights_desc_init(
+        deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+        const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+        const dims_t strides, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+            diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, nullptr,
+            padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_dilated_deconvolution_backward_weights_desc_init(
+        deconvolution_desc_t *deconv_desc, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *diff_weights_desc,
+        const memory_desc_t *diff_bias_desc, const memory_desc_t *diff_dst_desc,
+        const dims_t strides, const dims_t dilates, const dims_t padding_l,
+        const dims_t padding_r, padding_kind_t padding_kind) {
+    return deconv_desc_init(deconv_desc, backward_weights, alg_kind, src_desc,
+            diff_weights_desc, diff_bias_desc, diff_dst_desc, strides, dilates,
+            padding_l, padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
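deconv_desc_init mirrors conv_desc_init but swaps the roles of src and dst in the shape check: with pad = padding_l + padding_r it requires (dst - ker_range + pad) / str + 1 == src, i.e. the deconvolution output has the shape a convolution with the same parameters would consume. A small illustrative helper (not part of the vendored sources) for that inverse relation:

    // Output spatial size implied by the check in deconv_desc_init; it is
    // the inverse of the convolution formula sketched earlier.
    inline int deconv_out_dim(int src, int ker, int dil, int pad, int str) {
        const int ker_range = 1 + (ker - 1) * (dil + 1);
        return (src - 1) * str + ker_range - pad;
    }

    // Example: src = 32, 2x2 kernel, stride 2, no padding or dilation gives
    // deconv_out_dim(32, 2, 0, 0, 2) == 64, and (64 - 2 + 0) / 2 + 1 == 32.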

+ 293 - 0
thirdparty/oidn/mkl-dnn/src/common/deconvolution_pd.hpp

@@ -0,0 +1,293 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef DECONVOLUTION_PD_HPP
+#define DECONVOLUTION_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "convolution_pd.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct deconvolution_fwd_pd_t;
+
+struct deconvolution_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::deconvolution;
+
+    deconvolution_pd_t(engine_t *engine,
+            const deconvolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const deconvolution_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+    {}
+
+    const deconvolution_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case pkind_traits<base_pkind>::query_d:
+            *(const deconvolution_desc_t **)result = desc();
+            break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common deconv aux functions (note that conv_desc_t == deconv_desc_t) */
+
+    dim_t MB() const { return conv_prop_invariant_src_d(&desc_)->dims[0]; }
+
+    dim_t IC() const { return conv_prop_invariant_src_d(&desc_)->dims[1]; }
+    dim_t OC() const { return conv_prop_invariant_dst_d(&desc_)->dims[1]; }
+    dim_t G() const
+    { return with_groups() ? conv_prop_invariant_wei_d(&desc_)->dims[0] : 1; }
+
+    dim_t ID() const {
+        return ndims() >= 5
+            ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+    }
+    dim_t IH() const {
+        return ndims() >= 4
+            ? conv_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+    }
+    dim_t IW() const {
+        return conv_prop_invariant_src_d(&desc_)->dims[ndims() - 1];
+    }
+
+    dim_t OD() const {
+        return ndims() >= 5
+            ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+    }
+    dim_t OH() const {
+        return ndims() >= 4
+            ? conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+    }
+    dim_t OW() const {
+        return conv_prop_invariant_dst_d(&desc_)->dims[ndims() - 1];
+    }
+
+    dim_t KD() const {
+        const int w_ndims = ndims() + with_groups();
+        return ndims() >= 5
+            ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 3] : 1;
+    }
+    dim_t KH() const {
+        const int w_ndims = ndims() + with_groups();
+        return ndims() >= 4
+            ? conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 2] : 1;
+    }
+    dim_t KW() const {
+        const int w_ndims = ndims() + with_groups();
+        return conv_prop_invariant_wei_d(&desc_)->dims[w_ndims - 1];
+    }
+
+    dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+    dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+    dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+    dim_t KDD() const { return ndims() >= 5 ? desc_.dilates[ndims() - 5] : 0; }
+    dim_t KDH() const { return ndims() >= 4 ? desc_.dilates[ndims() - 4] : 1; }
+    dim_t KDW() const { return desc_.dilates[ndims() - 3]; }
+
+    dim_t padFront() const
+    { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+    dim_t padBack() const
+    { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+    dim_t padT() const
+    { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+    dim_t padB() const
+    { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+    dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+    dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+    bool with_bias() const {
+        return
+            !memory_desc_wrapper(*conv_prop_invariant_bia_d(&desc_)).is_zero();
+    }
+
+    bool with_groups() const
+    { return conv_prop_invariant_wei_d(&desc_)->ndims == ndims() + 1; }
+
+    int ndims() const { return conv_prop_invariant_src_d(&desc_)->ndims; }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+    bool has_zero_dim_memory() const {
+        const auto s_d = memory_desc_wrapper(*conv_prop_invariant_src_d(&desc_));
+        const auto d_d = memory_desc_wrapper(*conv_prop_invariant_dst_d(&desc_));
+        return s_d.has_zero_dim() || d_d.has_zero_dim();
+    }
+
+protected:
+    deconvolution_desc_t desc_;
+    const deconvolution_fwd_pd_t *hint_fwd_pd_;
+};
+
+struct deconvolution_fwd_pd_t: public deconvolution_pd_t {
+    typedef deconvolution_fwd_pd_t base_class;
+    typedef deconvolution_fwd_pd_t hint_class;
+
+    deconvolution_fwd_pd_t(engine_t *engine,
+            const deconvolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const deconvolution_fwd_pd_t *hint_fwd_pd)
+        : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , weights_md_(desc_.weights_desc)
+        , bias_md_(desc_.bias_desc)
+        , dst_md_(desc_.dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_BIAS && with_bias())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override {
+        if (index == 0) return &weights_md_;
+        if (index == 1 && with_bias()) return &bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2 + with_bias(); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t bias_md_;
+    memory_desc_t dst_md_;
+};
+
+struct deconvolution_bwd_data_pd_t: public deconvolution_pd_t {
+    typedef deconvolution_bwd_data_pd_t base_class;
+    typedef deconvolution_fwd_pd_t hint_class;
+
+    deconvolution_bwd_data_pd_t(engine_t *engine,
+            const deconvolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const deconvolution_fwd_pd_t *hint_fwd_pd)
+        : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_src_md_(desc_.diff_src_desc)
+        , weights_md_(desc_.weights_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override
+    { return index == 0 ? &weights_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t diff_src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t diff_dst_md_;
+};
+
+struct deconvolution_bwd_weights_pd_t: public deconvolution_pd_t {
+    typedef deconvolution_bwd_weights_pd_t base_class;
+    typedef deconvolution_fwd_pd_t hint_class;
+
+    deconvolution_bwd_weights_pd_t(engine_t *engine,
+            const deconvolution_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const deconvolution_fwd_pd_t *hint_fwd_pd)
+        : deconvolution_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , diff_weights_md_(desc_.diff_weights_desc)
+        , diff_bias_md_(desc_.diff_bias_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+        if (index == 0) return &diff_weights_md_;
+        if (index == 1 && with_bias()) return &diff_bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t diff_weights_md_;
+    memory_desc_t diff_bias_md_;
+    memory_desc_t diff_dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
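For orientation: all of the accessors above index from the tail of the dims array, so the same helpers cover 1D, 2D and 3D deconvolutions. A hand-worked illustration for a 2D case (the concrete sizes below are made up purely for the example):

    // Hypothetical 2D deconvolution, src dims = {N=8, C=16, H=32, W=32}, so ndims() == 4.
    // MB()  -> dims[0] == 8          IC()  -> dims[1] == 16
    // IH()  -> dims[4 - 2] == 32     IW()  -> dims[4 - 1] == 32
    // ID(), KD(), KSD()              -> 1 (depth accessors fall back when ndims() < 5)
    // KSH() -> strides[4 - 4] == strides[0], KSW() -> strides[4 - 3] == strides[1]
    // padT()/padB() -> padding[0][0] / padding[1][0]
    // padL()/padR() -> padding[0][1] / padding[1][1]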

+ 84 - 0
thirdparty/oidn/mkl-dnn/src/common/eltwise.cpp

@@ -0,0 +1,84 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t eltwise_desc_init(eltwise_desc_t *eltwise_desc, prop_kind_t prop_kind,
+        alg_kind_t alg_kind, const memory_desc_t *data_desc,
+        const memory_desc_t *diff_data_desc, float alpha, float beta) {
+    bool args_ok = true
+        && !any_null(eltwise_desc, data_desc)
+        && one_of(prop_kind, forward_training, forward_inference,
+                backward_data)
+        && one_of(alg_kind, eltwise_relu, eltwise_tanh, eltwise_elu,
+                  eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+                  eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic)
+        && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+    if (!args_ok) return invalid_arguments;
+
+    auto ed = eltwise_desc_t();
+    ed.primitive_kind = primitive_kind::eltwise;
+    ed.prop_kind = prop_kind;
+    ed.alg_kind = alg_kind;
+
+    ed.data_desc = *data_desc;
+    ed.diff_data_desc =
+        (ed.prop_kind == backward_data) ? *diff_data_desc : zero_md();
+
+    ed.alpha = alpha;
+    ed.beta = beta;
+
+    bool consistency = true
+        && IMPLICATION(ed.prop_kind == backward_data,
+                array_cmp(ed.diff_data_desc.dims, ed.data_desc.dims,
+                    ed.diff_data_desc.ndims));
+    if (!consistency) return invalid_arguments;
+
+    *eltwise_desc = ed;
+    return success;
+}
+}
+
+status_t mkldnn_eltwise_forward_desc_init(eltwise_desc_t *eltwise_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *data_desc, float alpha, float beta) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return eltwise_desc_init(eltwise_desc, prop_kind, alg_kind, data_desc,
+            nullptr, alpha, beta);
+}
+
+status_t mkldnn_eltwise_backward_desc_init(eltwise_desc_t *eltwise_desc,
+        alg_kind_t alg_kind, const memory_desc_t *diff_data_desc,
+        const memory_desc_t *data_desc, float alpha, float beta) {
+    return eltwise_desc_init(eltwise_desc, backward_data, alg_kind, data_desc,
+            diff_data_desc, alpha, beta);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
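A minimal usage sketch against the public header, assuming the usual mkldnn_-prefixed spellings of the types and enums involved (mkldnn_eltwise_desc_t, mkldnn_forward_inference, mkldnn_eltwise_relu); data_md is a caller-provided memory descriptor:

    #include "mkldnn.h"

    // Sketch: build a forward ReLU eltwise descriptor over an existing memory
    // descriptor. For ReLU, alpha is the negative slope and beta is ignored.
    static mkldnn_status_t make_relu_desc(const mkldnn_memory_desc_t *data_md,
            mkldnn_eltwise_desc_t *relu_d) {
        // eltwise_desc_init() above rejects backward prop kinds and unknown
        // alg kinds with invalid_arguments; this path only exercises the
        // forward-inference + relu combination.
        return mkldnn_eltwise_forward_desc_init(relu_d,
                mkldnn_forward_inference, mkldnn_eltwise_relu, data_md,
                /*alpha=*/0.f, /*beta=*/0.f);
    }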

+ 161 - 0
thirdparty/oidn/mkl-dnn/src/common/eltwise_pd.hpp

@@ -0,0 +1,161 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ELTWISE_PD_HPP
+#define ELTWISE_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct eltwise_fwd_pd_t;
+
+struct eltwise_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::eltwise;
+
+    eltwise_pd_t(mkldnn::impl::engine_t *engine,
+            const eltwise_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const eltwise_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , data_md_(desc_.data_desc)
+    {}
+
+    const eltwise_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::eltwise_d:
+            *(const eltwise_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common eltwise aux functions */
+
+    dim_t MB() const { return data_desc().dims[0]; }
+    dim_t C() const { return data_desc().dims[1]; }
+    dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+    dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+    dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+    int ndims() const { return data_desc().ndims; }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+    bool has_zero_dim_memory() const
+    { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+protected:
+    eltwise_desc_t desc_;
+    const eltwise_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t data_md_;
+
+private:
+    const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct eltwise_fwd_pd_t: public eltwise_pd_t {
+    typedef eltwise_fwd_pd_t base_class;
+    typedef eltwise_fwd_pd_t hint_class;
+
+    eltwise_fwd_pd_t(mkldnn::impl::engine_t *engine,
+            const eltwise_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const eltwise_fwd_pd_t *hint_fwd_pd)
+        : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override { return 1; }
+
+    bool is_zero_preserved() const
+    { return math::eltwise_fwd_preserves_zero(desc_.alg_kind); }
+};
+
+struct eltwise_bwd_pd_t: public eltwise_pd_t {
+    typedef eltwise_bwd_pd_t base_class;
+    typedef eltwise_fwd_pd_t hint_class;
+
+    eltwise_bwd_pd_t(engine_t *engine,
+            const eltwise_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const eltwise_fwd_pd_t *hint_fwd_pd)
+        : eltwise_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_data_md_(desc_.diff_data_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1; }
+
+    bool is_zero_preserved() const { return true; }
+
+protected:
+    memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 75 - 0
thirdparty/oidn/mkl-dnn/src/common/engine.cpp

@@ -0,0 +1,75 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+#include "engine.hpp"
+#include "nstl.hpp"
+
+#include "c_types_map.hpp"
+#include "../cpu/cpu_engine.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+engine_factory_t *engine_factories[] = {
+    &cpu::engine_factory,
+    nullptr,
+};
+
+static inline engine_factory_t *get_engine_factory(engine_kind_t kind) {
+    for (engine_factory_t **ef = engine_factories; *ef; ef++)
+        if ((*ef)->kind() == kind)
+            return *ef;
+    return nullptr;
+}
+
+}
+}
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+size_t mkldnn_engine_get_count(engine_kind_t kind) {
+    engine_factory_t *ef = get_engine_factory(kind);
+    return ef != nullptr ? ef->count() : 0;
+}
+
+status_t mkldnn_engine_create(engine_t **engine,
+        engine_kind_t kind, size_t index) {
+    if (engine == nullptr)
+        return invalid_arguments;
+
+    engine_factory_t *ef = get_engine_factory(kind);
+    if (ef == nullptr || index >= ef->count())
+        return invalid_arguments;
+
+    return ef->engine_create(engine, index);
+}
+
+status_t mkldnn_engine_get_kind(engine_t *engine, engine_kind_t *kind) {
+    if (engine == nullptr)
+        return invalid_arguments;
+    *kind = engine->kind();
+    return success;
+}
+
+status_t mkldnn_engine_destroy(engine_t *engine) {
+    /* TODO: engine->dec_ref_count(); */
+    delete engine;
+    return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
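A minimal sketch of the corresponding public calls, assuming the standard mkldnn_-prefixed C types (mkldnn_engine_t is the opaque handle, mkldnn_cpu the engine kind):

    #include "mkldnn.h"

    // Sketch: enumerate, create and tear down a CPU engine via the wrappers above.
    static mkldnn_status_t with_cpu_engine(void) {
        if (mkldnn_engine_get_count(mkldnn_cpu) == 0)
            return mkldnn_invalid_arguments;   // no CPU factory registered

        mkldnn_engine_t engine;
        mkldnn_status_t st = mkldnn_engine_create(&engine, mkldnn_cpu, /*index=*/0);
        if (st != mkldnn_success)
            return st;

        mkldnn_engine_kind_t kind;
        mkldnn_engine_get_kind(engine, &kind); // kind == mkldnn_cpu

        // ... build primitives on `engine` ...

        return mkldnn_engine_destroy(engine);  // currently a plain delete, see above
    }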

+ 119 - 0
thirdparty/oidn/mkl-dnn/src/common/engine.hpp

@@ -0,0 +1,119 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef ENGINE_HPP
+#define ENGINE_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive.hpp"
+#include "utils.hpp"
+
+/** \brief An abstraction of an execution unit with shared resources
+ *
+ * Responsibilities:
+ *   - Provide engine specific memory allocation
+ *   - Provide engine specific primitive_desc_t creators
+ */
+struct mkldnn_engine: public mkldnn::impl::c_compatible {
+    mkldnn_engine(mkldnn::impl::engine_kind_t kind)
+        : kind_(kind)
+    {}
+    virtual ~mkldnn_engine() {}
+
+    /** get kind of the current engine */
+    virtual mkldnn::impl::engine_kind_t kind() const { return kind_; }
+
+    /** allocate memory */
+    virtual mkldnn::impl::status_t memory_create(
+            mkldnn::impl::memory_t **memory,
+            const mkldnn::impl::memory_desc_t *md,
+            void *handle) = 0;
+
+    /** implementation section (typedefs) */
+
+    // TODO: remove engine?
+    typedef mkldnn::impl::status_t (*reorder_primitive_desc_create_f)(
+            mkldnn::impl::reorder_pd_t **reorder_pd,
+            mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::primitive_attr_t *attr,
+            mkldnn::impl::engine_t *src_engine,
+            const mkldnn::impl::memory_desc_t *src_md,
+            mkldnn::impl::engine_t *dst_engine,
+            const mkldnn::impl::memory_desc_t *dst_md);
+
+    typedef mkldnn::impl::status_t (*concat_primitive_desc_create_f)(
+            mkldnn::impl::concat_pd_t **concat_pd,
+            mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::primitive_attr_t *attr,
+            const mkldnn::impl::memory_desc_t *dst_md,
+            int n, int concat_dim,
+            const mkldnn::impl::memory_desc_t *src_mds);
+
+    typedef mkldnn::impl::status_t (*sum_primitive_desc_create_f)(
+            mkldnn::impl::sum_pd_t **sum_pd,
+            mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::primitive_attr_t *attr,
+            const mkldnn::impl::memory_desc_t *dst_md,
+            int n, const float *scales,
+            const mkldnn::impl::memory_desc_t *src_mds);
+
+    typedef mkldnn::impl::status_t (*primitive_desc_create_f)(
+            mkldnn::impl::primitive_desc_t **, const mkldnn::impl::op_desc_t *,
+            const mkldnn::impl::primitive_attr_t *attr,
+            mkldnn::impl::engine_t *, const mkldnn::impl::primitive_desc_t *);
+
+    /* implementation section */
+
+    /** return the list of reorder implementations. engine guarantees to return
+     * a NULL-terminated list */
+    virtual const reorder_primitive_desc_create_f*
+        get_reorder_implementation_list() const = 0;
+
+    /** return the list of concat implementations. engine guarantees to return
+     * a NULL-terminated list */
+    virtual const concat_primitive_desc_create_f*
+        get_concat_implementation_list() const = 0;
+
+    /** return the list of sum implementations. engine guarantees to return
+     * a NULL-terminated list */
+    virtual const sum_primitive_desc_create_f*
+        get_sum_implementation_list() const = 0;
+
+    /** return the list of implementations. engine guarantees to return a
+     * NULL-terminated list */
+    virtual const primitive_desc_create_f* get_implementation_list() const = 0;
+
+protected:
+    mkldnn::impl::engine_kind_t kind_;
+};
+
+namespace mkldnn {
+namespace impl {
+
+struct engine_factory_t: public c_compatible {
+    virtual size_t count() const = 0;
+    virtual engine_kind_t kind() const = 0;
+    virtual status_t engine_create(engine_t **engine, size_t index) const = 0;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 106 - 0
thirdparty/oidn/mkl-dnn/src/common/inner_product.cpp

@@ -0,0 +1,106 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t ip_desc_init(inner_product_desc_t *ip_desc, prop_kind_t prop_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *weights_desc,
+        const memory_desc_t *bias_desc, const memory_desc_t *dst_desc) {
+    bool args_ok = !any_null(ip_desc, src_desc, weights_desc, dst_desc);
+    if (!args_ok) return invalid_arguments;
+
+    auto id = inner_product_desc_t();
+    id.primitive_kind = primitive_kind::inner_product;
+    id.prop_kind = prop_kind;
+
+    id.diff_src_desc = id.src_desc = zero_md();
+    id.diff_dst_desc = id.dst_desc = zero_md();
+    id.diff_weights_desc = id.weights_desc = zero_md();
+    id.diff_bias_desc = id.bias_desc = zero_md();
+
+    const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+    const bool with_bias =
+        bias_desc && bias_desc->format_kind != format_kind::undef;
+
+    (prop_kind == backward_data ? id.diff_src_desc : id.src_desc) = *src_desc;
+    (is_fwd ? id.dst_desc : id.diff_dst_desc)  = *dst_desc;
+    (prop_kind == backward_weights ? id.diff_weights_desc : id.weights_desc) =
+        *weights_desc;
+    if (with_bias)
+        (prop_kind == backward_weights ? id.diff_bias_desc : id.bias_desc) =
+            *bias_desc;
+
+    id.accum_data_type = types::default_accum_data_type(src_desc->data_type,
+            weights_desc->data_type, dst_desc->data_type, prop_kind);
+
+    bool consistency = true
+        && memory_desc_wrapper(weights_desc).nelems()
+        && one_of(src_desc->ndims, 2, 3, 4, 5)
+        && dst_desc->ndims == 2
+        && weights_desc->ndims == src_desc->ndims
+        && (with_bias ? bias_desc->ndims == 1 : true)
+        && (with_bias ? bias_desc->dims[0] == dst_desc->dims[1] : true)
+        && src_desc->dims[0] == dst_desc->dims[0]
+        && array_cmp(&src_desc->dims[1], &weights_desc->dims[1],
+                src_desc->ndims - 1)
+        && dst_desc->dims[1] == weights_desc->dims[0];
+    if (!consistency) return invalid_arguments;
+
+    *ip_desc = id;
+    return success;
+}
+}
+
+status_t mkldnn_inner_product_forward_desc_init(inner_product_desc_t *ip_desc,
+        prop_kind_t prop_kind, const memory_desc_t *src_desc,
+        const memory_desc_t *weights_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_desc) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return ip_desc_init(ip_desc, prop_kind, src_desc, weights_desc, bias_desc,
+            dst_desc);
+}
+
+status_t mkldnn_inner_product_backward_data_desc_init(
+        inner_product_desc_t *ip_desc, const memory_desc_t *diff_src_desc,
+        const memory_desc_t *weights_desc, const memory_desc_t *diff_dst_desc)
+{
+    return ip_desc_init(ip_desc, backward_data, diff_src_desc, weights_desc,
+            nullptr, diff_dst_desc);
+}
+
+status_t mkldnn_inner_product_backward_weights_desc_init(
+        inner_product_desc_t *ip_desc, const memory_desc_t *src_desc,
+        const memory_desc_t *diff_weights_desc,
+        const memory_desc_t *diff_bias_desc,
+        const memory_desc_t *diff_dst_desc) {
+    return ip_desc_init(ip_desc, backward_weights, src_desc, diff_weights_desc,
+            diff_bias_desc, diff_dst_desc);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
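A thin public wrapper fronts this as well; a sketch of the forward case, assuming mkldnn_-prefixed type names and caller-provided memory descriptors that satisfy the consistency block above (src {N, IC, ...}, weights {OC, IC, ...}, dst {N, OC}, optional 1-D bias {OC}):

    #include "mkldnn.h"

    // Sketch: forward inner product descriptor; pass bias_md == NULL for the
    // bias-less variant (ip_desc_init() treats an absent/undef bias as "no bias").
    static mkldnn_status_t make_ip_fwd_desc(const mkldnn_memory_desc_t *src_md,
            const mkldnn_memory_desc_t *weights_md,
            const mkldnn_memory_desc_t *bias_md,
            const mkldnn_memory_desc_t *dst_md,
            mkldnn_inner_product_desc_t *ip_d) {
        return mkldnn_inner_product_forward_desc_init(ip_d,
                mkldnn_forward_training, src_md, weights_md, bias_md, dst_md);
    }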

+ 56 - 0
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.cpp

@@ -0,0 +1,56 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "utils.hpp"
+
+#include "inner_product_pd.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+using namespace prop_kind;
+
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc) {
+    return desc->prop_kind == backward_data
+        ? &desc->diff_src_desc : &desc->src_desc;
+}
+
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc) {
+    return desc->prop_kind == backward_weights
+        ? &desc->diff_weights_desc : &desc->weights_desc;
+}
+
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc) {
+    return desc->prop_kind == backward_weights
+        ? &desc->diff_bias_desc : &desc->bias_desc;
+}
+
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc) {
+    return utils::one_of(desc->prop_kind, forward_inference, forward_training)
+        ? &desc->dst_desc : &desc->diff_dst_desc;
+}
+
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_src_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_wei_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_bia_d(const_cast<inner_product_desc_t *>(desc)); }
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc)
+{ return ip_prop_invariant_dst_d(const_cast<inner_product_desc_t *>(desc)); }
+
+}
+}

+ 321 - 0
thirdparty/oidn/mkl-dnn/src/common/inner_product_pd.hpp

@@ -0,0 +1,321 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef INNER_PRODUCT_PD_HPP
+#define INNER_PRODUCT_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+memory_desc_t *ip_prop_invariant_src_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_wei_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_bia_d(inner_product_desc_t *desc);
+memory_desc_t *ip_prop_invariant_dst_d(inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_src_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_wei_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_bia_d(const inner_product_desc_t *desc);
+const memory_desc_t *ip_prop_invariant_dst_d(const inner_product_desc_t *desc);
+
+struct inner_product_fwd_pd_t;
+
+struct inner_product_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::inner_product;
+
+    inner_product_pd_t(engine_t *engine,
+            const inner_product_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const inner_product_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+    {}
+
+    const inner_product_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::inner_product_d:
+            *(const inner_product_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common inner_product aux functions */
+
+    dim_t MB() const { return ip_prop_invariant_src_d(&desc_)->dims[0]; }
+    dim_t IC() const { return ip_prop_invariant_src_d(&desc_)->dims[1]; }
+    dim_t OC() const { return ip_prop_invariant_dst_d(&desc_)->dims[1]; }
+
+    dim_t ID() const {
+        return ndims() >= 5
+            ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 3] : 1;
+    }
+    dim_t IH() const {
+        return ndims() >= 4
+            ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 2] : 1;
+    }
+    dim_t IW() const {
+        return ndims() >= 3
+            ? ip_prop_invariant_src_d(&desc_)->dims[ndims() - 1] : 1;
+    }
+
+    dim_t OD() const {
+        return ndims() >= 5
+            ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 3] : 1;
+    }
+    dim_t OH() const {
+        return ndims() >= 4
+            ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 2] : 1;
+    }
+    dim_t OW() const {
+        return ndims() >= 3
+            ? ip_prop_invariant_dst_d(&desc_)->dims[ndims() - 1] : 1;
+    }
+
+    dim_t KD() const {
+        return ndims() >= 5
+            ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 3] : 1;
+    }
+    dim_t KH() const {
+        return ndims() >= 4
+            ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 2] : 1;
+    }
+    dim_t KW() const {
+        return ndims() >= 3
+            ? ip_prop_invariant_wei_d(&desc_)->dims[ndims() - 1] : 1;
+    }
+
+    dim_t IC_total() const {
+        return utils::array_product(&ip_prop_invariant_src_d(&desc_)->dims[1],
+                ndims() - 1);
+    }
+
+    dim_t IC_total_padded() const {
+        auto src_d = desc()->prop_kind == prop_kind::backward_data
+            ? memory_desc_wrapper(diff_src_md())
+            : memory_desc_wrapper(src_md());
+        assert(src_d.is_blocking_desc());
+        if (!src_d.is_blocking_desc()) return -1;
+        return utils::array_product(src_d.padded_dims() + 1, ndims() - 1);
+    }
+
+    int ndims() const { return ip_prop_invariant_src_d(&desc_)->ndims; }
+
+    bool with_bias() const
+    { return !memory_desc_wrapper(*ip_prop_invariant_bia_d(&desc_)).is_zero(); }
+
+    bool has_zero_dim_memory() const {
+        const auto s_d = memory_desc_wrapper(*ip_prop_invariant_src_d(&desc_));
+        const auto d_d = memory_desc_wrapper(*ip_prop_invariant_dst_d(&desc_));
+        return s_d.has_zero_dim() || d_d.has_zero_dim();
+    }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+protected:
+    inner_product_desc_t desc_;
+    const inner_product_fwd_pd_t *hint_fwd_pd_;
+
+    status_t template_set_default_params(memory_desc_t &src_md,
+            memory_desc_t &weights_md, memory_desc_t &dst_md,
+            memory_desc_t *bias_md) {
+        using namespace format_tag;
+        if (src_md.format_kind == format_kind::any) {
+            CHECK(memory_desc_init_by_tag(src_md,
+                        utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)));
+        }
+        if (dst_md.format_kind == format_kind::any)
+            CHECK(memory_desc_init_by_tag(dst_md, nc));
+        if (weights_md.format_kind == format_kind::any) {
+            CHECK(memory_desc_init_by_tag(weights_md,
+                        utils::pick(ndims() - 2, oi, oiw, oihw, oidhw)));
+        }
+        if (bias_md && bias_md->format_kind == format_kind::any)
+            CHECK(memory_desc_init_by_tag(*bias_md, x));
+        return status::success;
+    }
+};
+
+struct inner_product_fwd_pd_t: public inner_product_pd_t {
+    typedef inner_product_fwd_pd_t base_class;
+    typedef inner_product_fwd_pd_t hint_class;
+
+    inner_product_fwd_pd_t(engine_t *engine,
+            const inner_product_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const inner_product_fwd_pd_t *hint_fwd_pd)
+        : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , weights_md_(desc_.weights_desc)
+        , bias_md_(desc_.bias_desc)
+        , dst_md_(desc_.dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_WEIGHTS))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_BIAS && with_bias())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override {
+        if (index == 0) return &weights_md_;
+        if (index == 1 && with_bias()) return &bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2 + with_bias(); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t bias_md_;
+    memory_desc_t dst_md_;
+
+    status_t set_default_params() {
+        return template_set_default_params(src_md_, weights_md_, dst_md_,
+                &bias_md_);
+    }
+};
+
+struct inner_product_bwd_data_pd_t: public inner_product_pd_t {
+    typedef inner_product_bwd_data_pd_t base_class;
+    typedef inner_product_fwd_pd_t hint_class;
+
+    inner_product_bwd_data_pd_t(engine_t *engine,
+            const inner_product_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const inner_product_fwd_pd_t *hint_fwd_pd)
+        : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_src_md_(desc_.diff_src_desc)
+        , weights_md_(desc_.weights_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *weights_md(int index = 0) const override
+    { return index == 0 ? &weights_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t diff_src_md_;
+    memory_desc_t weights_md_;
+    memory_desc_t diff_dst_md_;
+
+    status_t set_default_params() {
+        return template_set_default_params(diff_src_md_, weights_md_,
+                diff_dst_md_, nullptr);
+    }
+};
+
+struct inner_product_bwd_weights_pd_t: public inner_product_pd_t {
+    typedef inner_product_bwd_weights_pd_t base_class;
+    typedef inner_product_fwd_pd_t hint_class;
+
+    inner_product_bwd_weights_pd_t(engine_t *engine,
+            const inner_product_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const inner_product_fwd_pd_t *hint_fwd_pd)
+        : inner_product_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , diff_weights_md_(desc_.diff_weights_desc)
+        , diff_bias_md_(desc_.diff_bias_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_WEIGHTS)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_DIFF_BIAS && with_bias())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *diff_weights_md(int index = 0) const override {
+        if (index == 0) return &diff_weights_md_;
+        if (index == 1 && with_bias()) return &diff_bias_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override { return 2; }
+    virtual int n_outputs() const override { return 1 + with_bias(); }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t diff_weights_md_;
+    memory_desc_t diff_bias_md_;
+    memory_desc_t diff_dst_md_;
+
+    status_t set_default_params() {
+        return template_set_default_params(src_md_, diff_weights_md_,
+                diff_dst_md_, &diff_bias_md_);
+    }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
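For reference, how template_set_default_params() above resolves format_kind::any:

    // src:     utils::pick(ndims() - 2, nc, ncw, nchw, ncdhw)
    //          -> nc (2D), ncw (3D), nchw (4D), ncdhw (5D)
    // weights: oi (2D), oiw (3D), oihw (4D), oidhw (5D)
    // dst:     always nc; bias (when present): always x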

+ 91 - 0
thirdparty/oidn/mkl-dnn/src/common/lrn.cpp

@@ -0,0 +1,91 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t lrn_desc_init(lrn_desc_t *lrn_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *data_desc, const memory_desc_t *diff_data_desc,
+        dim_t local_size, float alpha, float beta, float k) {
+    bool args_ok = true
+        && !any_null(lrn_desc, data_desc)
+        && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
+        && one_of(prop_kind, forward_training, forward_inference, backward_data)
+        && IMPLICATION(prop_kind == backward_data, diff_data_desc != nullptr);
+    if (!args_ok) return invalid_arguments;
+
+    auto ld = lrn_desc_t();
+    ld.primitive_kind = primitive_kind::lrn;
+    ld.prop_kind = prop_kind;
+    ld.alg_kind = alg_kind;
+
+    const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+    ld.data_desc = *data_desc;
+    if (!is_fwd)
+        ld.diff_data_desc = *diff_data_desc;
+    else
+        ld.diff_data_desc = zero_md();
+    ld.local_size = local_size;
+    ld.lrn_alpha = alpha;
+    ld.lrn_beta = beta;
+    ld.lrn_k = k;
+
+    bool consistency = true
+        && ld.data_desc.ndims == 4;
+    if (ld.prop_kind == backward_data)
+        consistency = consistency
+            && ld.diff_data_desc.ndims == 4
+            && array_cmp(ld.diff_data_desc.dims, ld.data_desc.dims, 4);
+    if (!consistency) return invalid_arguments;
+
+    *lrn_desc = ld;
+    return success;
+}
+}
+
+status_t mkldnn_lrn_forward_desc_init(lrn_desc_t *lrn_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *data_desc, dim_t local_size, float alpha,
+        float beta, float k) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
+            local_size, alpha, beta, k);
+}
+
+status_t mkldnn_lrn_backward_desc_init(lrn_desc_t *lrn_desc,
+        alg_kind_t alg_kind, const memory_desc_t *data_desc,
+        const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
+        float beta, float k) {
+    return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
+            diff_data_desc, local_size, alpha, beta, k);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
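A sketch of the public forward call, assuming the mkldnn_-prefixed spellings (mkldnn_lrn_desc_t, mkldnn_lrn_across_channels) and a caller-provided 4-D memory descriptor; the numeric values are just the common AlexNet-style LRN setting, not something this file prescribes:

    #include "mkldnn.h"

    // Sketch: forward cross-channel LRN. lrn_desc_init() above additionally
    // requires the data descriptor to be 4-dimensional.
    static mkldnn_status_t make_lrn_fwd_desc(const mkldnn_memory_desc_t *data_md,
            mkldnn_lrn_desc_t *lrn_d) {
        return mkldnn_lrn_forward_desc_init(lrn_d,
                mkldnn_forward_inference, mkldnn_lrn_across_channels, data_md,
                /*local_size=*/5, /*alpha=*/1e-4f, /*beta=*/0.75f, /*k=*/2.f);
    }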

+ 170 - 0
thirdparty/oidn/mkl-dnn/src/common/lrn_pd.hpp

@@ -0,0 +1,170 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef LRN_PD_HPP
+#define LRN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct lrn_fwd_pd_t;
+
+struct lrn_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::lrn;
+
+    lrn_pd_t(engine_t *engine,
+            const lrn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const lrn_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , data_md_(desc_.data_desc)
+        , ws_md_()
+    {}
+
+    const lrn_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::lrn_d:
+            *(const lrn_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common lrn aux functions */
+
+    dim_t MB() const { return data_desc().dims[0]; }
+    dim_t C() const { return data_desc().dims[1]; }
+    dim_t D() const { return ndims() >= 5 ? data_desc().dims[ndims() - 3] : 1; }
+    dim_t H() const { return ndims() >= 4 ? data_desc().dims[ndims() - 2] : 1; }
+    dim_t W() const { return ndims() >= 3 ? data_desc().dims[ndims() - 1] : 1; }
+
+    int ndims() const { return data_desc().ndims; }
+
+    bool has_zero_dim_memory() const
+    { return memory_desc_wrapper(desc_.data_desc).has_zero_dim(); }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+protected:
+    lrn_desc_t desc_;
+    const lrn_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t data_md_;
+    memory_desc_t ws_md_;
+
+private:
+    const memory_desc_t &data_desc() const { return desc_.data_desc; }
+};
+
+struct lrn_fwd_pd_t: public lrn_pd_t {
+    typedef lrn_fwd_pd_t base_class;
+    typedef lrn_fwd_pd_t hint_class;
+
+    lrn_fwd_pd_t(engine_t *engine,
+            const lrn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const lrn_fwd_pd_t *hint_fwd_pd)
+        : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override
+    { return 1 + (workspace_md() != nullptr); }
+};
+
+struct lrn_bwd_pd_t: public lrn_pd_t {
+    typedef lrn_bwd_pd_t base_class;
+    typedef lrn_fwd_pd_t hint_class;
+
+    lrn_bwd_pd_t(engine_t *engine,
+            const lrn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const lrn_fwd_pd_t *hint_fwd_pd)
+        : lrn_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_data_md_(desc_.diff_data_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC, MKLDNN_ARG_DIFF_DST))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::input;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_data_md_ : nullptr; }
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    virtual int n_inputs() const override
+    { return 2 + (workspace_md() != nullptr); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t diff_data_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 280 - 0
thirdparty/oidn/mkl-dnn/src/common/math_utils.hpp

@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MATH_UTILS_HPP
+#define MATH_UTILS_HPP
+
+#include <stdint.h>
+#include <math.h>
+
+#include "utils.hpp"
+#include "nstl.hpp"
+#include "mkldnn_traits.hpp"
+
+#if defined(MKLDNN_X86_64)
+#include "immintrin.h"
+#endif
+
+namespace mkldnn {
+namespace impl {
+namespace math {
+
+/** rounds @p f to an integer according to the mxcsr register */
+inline int mxcsr_round(float f) {
+#if defined(MKLDNN_X86_64)
+    return _mm_cvtss_si32(_mm_load_ss(&f));
+#else
+    return (int)nearbyintf(f); // optimism
+#endif
+}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<!nstl::is_integral<data_t>::value,
+       typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+    return (typename utils::remove_reference<data_t>::type)x;
+}
+
+template <typename data_t, typename acc_t>
+inline typename utils::enable_if<nstl::is_integral<data_t>::value,
+       typename utils::remove_reference<data_t>::type>::type
+saturate(const acc_t &x) {
+    acc_t v = x;
+    if (v < (acc_t)nstl::numeric_limits<data_t>::lowest())
+        v = (acc_t)nstl::numeric_limits<data_t>::lowest();
+    if (v > (acc_t)nstl::numeric_limits<data_t>::max())
+        v = (acc_t)nstl::numeric_limits<data_t>::max();
+    return (typename utils::remove_reference<data_t>::type)v;
+}
+
+template <typename data_t>
+double saturate(const double &x) {
+    double v = x;
+    if (v < (double)nstl::numeric_limits<data_t>::lowest())
+        v = (double)nstl::numeric_limits<data_t>::lowest();
+    if (v > (double)nstl::numeric_limits<data_t>::max())
+        v = (double)nstl::numeric_limits<data_t>::max();
+    return v;
+}
+
+template <> inline int8_t saturate<int8_t, uint8_t>(const uint8_t &x) {
+    return x <= 127u ? x : 127;
+}
+
+template <> inline uint8_t saturate<uint8_t, int8_t>(const int8_t &x) {
+    return x >= 0 ? x : 0;
+}
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return (out_t)mxcsr_round(v); }
+
+template <typename out_t>
+typename utils::enable_if<nstl::is_integral<out_t>::value, out_t>::type
+out_round(double v) { return (out_t)mxcsr_round((float)v); }
+
+template <typename out_t>
+typename utils::enable_if<!nstl::is_integral<out_t>::value, out_t>::type
+out_round(float v) { return v; }
+
+inline int gcd(int a, int b) {
+    a = impl::nstl::abs(a);
+    b = impl::nstl::abs(b);
+    if (a < b) { int x = a; a = b; b = x; }
+
+    if (b == 0) return a;
+
+    int r;
+    while ((r = a % b) != 0) { a = b; b = r; }
+
+    return b;
+}
+
+template <typename T>
+inline bool is_pow2(const T& v) { return (v & (v - 1)) == 0; }
+
+/** returns floor(log2(v)), aka the position of the leftmost non-0 bit */
+inline int ilog2q(size_t v) {
+    if (v == 0)
+        return -1;
+
+    int p = 0;
+#   define CP(pw) do { if (v >= (1ull << pw)) { v >>= pw; p += pw; } } while(0)
+    CP(32); CP(16); CP(8); CP(4); CP(2); CP(1);
+#   undef CP
+    return p;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U one_m_square(T x) {
+    return (U)(1 - x) * (1 + x);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U x_m_square(T x) {
+    return (U)(1 - x) * x;
+}
+
+/* activation */
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U relu_fwd(T s, A alpha) {
+    return s > 0 ? s : (U)(s * alpha);
+}
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U relu_bwd(T dd, T s, A alpha) {
+    return s > 0 ? dd : (U)(dd * alpha);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_fwd(T s) {
+    const float e = tanhf((float) s);
+    return (U)e;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U tanh_bwd(T dd, T s) {
+    const float e = tanh_fwd<float>((float) s);
+    return (U)(dd * (1 - e) * (1 + e));
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U elu_fwd(T s, A alpha) {
+    return s > 0 ? s : (U)(alpha * (::expm1f((float)s)));
+}
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U elu_bwd(T dd, T s, A alpha) {
+    return (U)(dd * (s > 0 ? 1 : alpha * ::expf((float)s)));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_fwd(T s) {
+    return s * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U square_bwd(T dd, T s) {
+    return dd * 2 * s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_fwd(T s) {
+    return s > 0 ? s : -s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U abs_bwd(T dd, T s) {
+    return s > 0 ? dd : s < 0 ? -dd : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_fwd(T s) {
+    return s > 0 ? (U)(::sqrtf((float)(s))) : 0;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U sqrt_bwd(T dd, T s) {
+    return s > 0
+        ? (U)(dd / (2 * ::sqrtf((float)(s))))
+        : 0;
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U linear_fwd(T s, A alpha, A beta) {
+    return (U)(alpha * s + beta);
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U linear_bwd(T dd, T s, A alpha, A beta) {
+    (void) s;
+    (void) beta;
+    return (U)(dd * alpha);
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_fwd(T s, A alpha) {
+    s = s > 0 ? s : 0;
+    return s > alpha ? (U)(alpha) : s;
+}
+
+template <typename T, typename A,
+         typename U = typename utils::remove_reference<T>::type>
+inline U bounded_relu_bwd(T dd, T s, A alpha) {
+    return dd * (0 < s && s < alpha ? 1 : 0);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_fwd(T s) {
+    float max_logf = 8.872284e+01; //::logf(FLT_MAX)
+    return s < max_logf ? (U)(::log1pf(::expf((float)s))) : s;
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U soft_relu_bwd(T dd, T s) {
+    return (U)(dd / (1 + ::expf((float)(-s))));
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_fwd(T s) {
+    U v = (U)(::expf((float) -s));
+    return 1 / (1 + v);
+}
+
+template <typename T, typename U = typename utils::remove_reference<T>::type>
+inline U logistic_bwd(T dd, T s) {
+    U v = logistic_fwd<T, U>(s);
+    return dd * v * (1 - v);
+}
+
+inline bool eltwise_fwd_preserves_zero(alg_kind_t alg, bool jit_impl = false) {
+    using namespace alg_kind;
+    using namespace utils;
+    const bool preserves_zero = true
+        && !one_of(alg, eltwise_linear, eltwise_soft_relu, eltwise_logistic)
+        && IMPLICATION(jit_impl, !one_of(alg, eltwise_elu, eltwise_tanh));
+    return preserves_zero;
+}
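+// Rationale: linear(0) = beta, soft_relu(0) = log(2) and logistic(0) = 0.5, so
+// these algorithms do not keep zero padding at zero. The extra jit exclusion of
+// elu/tanh is presumably because the jitted exp/tanh approximations need not
+// return exactly 0 at 0.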
+
+inline float get_bias(const char *bias, size_t offset, data_type_t data_type)
+{
+    if (!bias)
+        return 0.0f;
+
+#define CASE(dt) \
+    case dt: return (float)((const prec_traits<dt>::type *)bias)[offset]
+
+    switch (data_type) {
+    CASE(data_type::s8);
+    CASE(data_type::u8);
+    CASE(data_type::s32);
+    CASE(data_type::f32);
+    default: assert(!"unimplemented");
+    }
+    return 0; // never happens (should probably be a NaN)
+#undef CASE
+}
+
+}
+}
+}
+
+#endif

+ 238 - 0
thirdparty/oidn/mkl-dnn/src/common/memory.cpp

@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::data_type;
+
+namespace {
+bool memory_desc_sanity_check(int ndims, const dims_t dims,
+        data_type_t data_type, format_kind_t format_kind) {
+    if (ndims == 0) return true;
+
+    bool ok = true
+        && dims != nullptr
+        && 0 < ndims && ndims <= MKLDNN_MAX_NDIMS
+        && one_of(data_type, f32, s32, s8, u8)
+        && format_kind != format_kind::undef;
+    if (!ok) return false;
+    for (int d = 0; d < ndims; ++d)
+        if (dims[d] < 0) return false;
+
+    return true;
+}
+
+bool memory_desc_sanity_check(const memory_desc_t *md) {
+    if (md == nullptr) return false;
+    return memory_desc_sanity_check(md->ndims, md->dims, md->data_type,
+            format_kind::any);
+}
+}
+
+status_t mkldnn_memory_desc_init_by_tag(memory_desc_t *memory_desc, int ndims,
+        const dims_t dims, data_type_t data_type, format_tag_t tag) {
+    if (any_null(memory_desc)) return invalid_arguments;
+    if (ndims == 0 || tag == format_tag::undef) {
+        *memory_desc = types::zero_md();
+        return success;
+    }
+
+    format_kind_t format_kind = types::format_tag_to_kind(tag);
+
+    /* memory_desc != 0 */
+    bool args_ok = !any_null(memory_desc)
+        && memory_desc_sanity_check(ndims, dims, data_type, format_kind);
+    if (!args_ok) return invalid_arguments;
+
+    auto md = memory_desc_t();
+    md.ndims = ndims;
+    array_copy(md.dims, dims, ndims);
+    md.data_type = data_type;
+    array_copy(md.padded_dims, dims, ndims);
+    md.format_kind = format_kind;
+
+    status_t status = success;
+    if (tag == format_tag::undef) {
+        status = invalid_arguments;
+    } else if (tag == format_tag::any) {
+        // nop
+    } else if (format_kind == format_kind::blocked) {
+        status = memory_desc_wrapper::compute_blocking(md, tag);
+    } else {
+        assert(!"unreachable");
+        status = invalid_arguments;
+    }
+
+    if (status == success)
+        *memory_desc = md;
+
+    return status;
+}
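+// Illustrative example: ndims = 4, dims = {N, C, H, W} with the plain tag
+// abcd (a.k.a. nchw) yields blocking strides {C*H*W, H*W, W, 1}, computed by
+// memory_desc_wrapper::compute_blocking().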
+
+status_t mkldnn_memory_desc_init_by_strides(memory_desc_t *memory_desc,
+        int ndims, const dims_t dims, data_type_t data_type,
+        const dims_t strides) {
+    if (any_null(memory_desc)) return invalid_arguments;
+    if (ndims == 0) {
+        *memory_desc = types::zero_md();
+        return success;
+    }
+
+    /* memory_desc != 0 */
+    bool args_ok = !any_null(memory_desc)
+        && memory_desc_sanity_check(ndims, dims, data_type, format_kind::any);
+    if (!args_ok) return invalid_arguments;
+
+    auto md = memory_desc_t();
+    md.ndims = ndims;
+    array_copy(md.dims, dims, ndims);
+    md.data_type = data_type;
+    array_copy(md.padded_dims, dims, ndims);
+    md.format_kind = format_kind::blocked;
+
+    dims_t default_strides = {0};
+    if (strides == nullptr) {
+        default_strides[md.ndims - 1] = 1;
+        for (int d = md.ndims - 2; d >= 0; --d)
+            default_strides[d] = default_strides[d + 1] * md.padded_dims[d + 1];
+        strides = default_strides;
+    } else {
+        /* TODO: add sanity check for the provided strides */
+    }
+
+    array_copy(md.format_desc.blocking.strides, strides, md.ndims);
+
+    *memory_desc = md;
+
+    return status::success;
+}
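+// When strides == nullptr a dense row-major default is derived from the padded
+// dims, e.g. dims {2, 3, 4} yields strides {12, 4, 1}.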
+
+status_t mkldnn_memory_desc_init_submemory(memory_desc_t *md,
+        const memory_desc_t *parent_md, const dims_t dims,
+        const dims_t offsets) {
+    if (any_null(md, parent_md) || !memory_desc_sanity_check(parent_md))
+        return invalid_arguments;
+
+    const memory_desc_wrapper src_d(parent_md);
+
+    for (int d = 0; d < src_d.ndims(); ++d) {
+        if (dims[d] < 0 || offsets[d] < 0
+                || (offsets[d] + dims[d] > src_d.dims()[d]))
+            return invalid_arguments;
+    }
+
+    if (src_d.format_kind() != format_kind::blocked)
+        return unimplemented;
+
+    dims_t blocks;
+    src_d.compute_blocks(blocks);
+
+    memory_desc_t dst_d = *parent_md;
+    auto &dst_d_blk = dst_d.format_desc.blocking;
+
+    /* TODO: put this into memory_desc_wrapper */
+    for (int d = 0; d < src_d.ndims(); ++d) {
+        /* very limited functionality for now */
+        const bool ok = true
+            && offsets[d] % blocks[d] == 0 /* [r1] */
+            && src_d.padded_offsets()[d] == 0
+            && (false
+                    || dims[d] % blocks[d] == 0
+                    || dims[d] < blocks[d]);
+        if (!ok)
+            return unimplemented;
+
+        const bool is_right_border = offsets[d] + dims[d] == src_d.dims()[d];
+
+        dst_d.dims[d] = dims[d];
+        dst_d.padded_dims[d] = is_right_border
+            ? src_d.padded_dims()[d] - offsets[d] : dst_d.dims[d];
+        dst_d.padded_offsets[d] = src_d.padded_offsets()[d];
+        dst_d.offset0 += /* [r1] */
+            offsets[d] / blocks[d] * dst_d_blk.strides[d];
+    }
+
+    *md = dst_d;
+
+    return success;
+}
+
+int mkldnn_memory_desc_equal(const memory_desc_t *lhs,
+        const memory_desc_t *rhs) {
+    if (lhs == rhs) return 1;
+    if (any_null(lhs, rhs)) return 0;
+    return memory_desc_wrapper(*lhs) == memory_desc_wrapper(*rhs);
+}
+
+size_t mkldnn_memory_desc_get_size(const memory_desc_t *md) {
+    if (md == nullptr) return 0;
+    return memory_desc_wrapper(*md).size();
+}
+
+status_t mkldnn_memory_create(memory_t **memory, const memory_desc_t *md,
+        engine_t *engine, void *handle) {
+    if (any_null(memory, engine)) return invalid_arguments;
+    memory_desc_t z_md = types::zero_md();
+    return engine->memory_create(memory, md ? md : &z_md, handle);
+}
+
+status_t mkldnn_memory_get_memory_desc(const memory_t *memory,
+        const memory_desc_t **md) {
+    if (any_null(memory, md)) return invalid_arguments;
+    *md = memory->md();
+    return success;
+}
+
+status_t mkldnn_memory_get_engine(const memory_t *memory, engine_t **engine) {
+    if (any_null(memory, engine)) return invalid_arguments;
+    *engine = memory->engine();
+    return success;
+}
+
+status_t mkldnn_memory_get_data_handle(const memory_t *memory,
+        void **handle) {
+    if (any_null(handle))
+        return invalid_arguments;
+    if (memory == nullptr) {
+        *handle = nullptr;
+        return success;
+    }
+    return memory->get_data_handle(handle);
+}
+
+status_t mkldnn_memory_set_data_handle(memory_t *memory, void *handle) {
+    if (any_null(memory)) return invalid_arguments;
+    return memory->set_data_handle(handle);
+}
+
+status_t mkldnn_memory_destroy(memory_t *memory) {
+    delete memory;
+    return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 63 - 0
thirdparty/oidn/mkl-dnn/src/common/memory.hpp

@@ -0,0 +1,63 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_HPP
+#define MEMORY_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+
+struct mkldnn_memory: public mkldnn::impl::c_compatible {
+    mkldnn_memory(mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::memory_desc_t *md)
+        : engine_(engine), md_(*md) {}
+    virtual ~mkldnn_memory() {}
+
+    /** allocates/initializes memory */
+    virtual mkldnn::impl::status_t init() = 0;
+
+    /** returns memory's engine */
+    mkldnn::impl::engine_t *engine() const { return engine_; }
+    /** returns memory's description */
+    const mkldnn::impl::memory_desc_t *md() const { return &md_; }
+
+    /** returns data handle */
+    virtual mkldnn::impl::status_t get_data_handle(void **handle) const = 0;
+
+    /** sets data handle */
+    virtual mkldnn::impl::status_t set_data_handle(void *handle) = 0;
+
+    /** zeros padding */
+    virtual mkldnn::impl::status_t zero_pad() const
+    { return mkldnn::impl::status::success; }
+
+protected:
+    mkldnn::impl::engine_t *engine_;
+    const mkldnn::impl::memory_desc_t md_;
+
+private:
+    mkldnn_memory() = delete;
+    mkldnn_memory(const mkldnn_memory &) = delete;
+    mkldnn_memory(mkldnn_memory &&) = delete;
+    mkldnn_memory &operator=(const mkldnn_memory &) = delete;
+    mkldnn_memory &operator=(mkldnn_memory &&) = delete;
+};
+
+#endif

+ 212 - 0
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.cpp

@@ -0,0 +1,212 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include <initializer_list>
+
+#include "c_types_map.hpp"
+#include "memory_desc_wrapper.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t fill_blocked(memory_desc_t &md,
+        std::initializer_list<int> perm,
+        std::initializer_list<int> inner_blks,
+        std::initializer_list<int> inner_idxs) {
+    const bool ok = true
+        && perm.size() == (size_t)md.ndims
+        && inner_blks.size() == inner_idxs.size();
+    if (!ok) return status::invalid_arguments;
+
+    md.offset0 = 0;
+
+    blocking_desc_t &blk = md.format_desc.blocking;
+
+    dim_t block_size = 1;
+    dims_t blocks = {0};
+    utils::array_set(blocks, 1, md.ndims);
+
+    blk.inner_nblks = (int)inner_blks.size();
+
+    int iblk = 0;
+    for (const auto &b: inner_idxs)
+        blk.inner_idxs[iblk++] = b;
+
+    iblk = 0;
+    for (const auto &b: inner_blks) {
+        int dim = blk.inner_idxs[iblk];
+        block_size *= b;
+        blocks[dim] *= b;
+        blk.inner_blks[iblk++] = b;
+    }
+
+    utils::array_set(md.padded_offsets, 0, md.ndims);
+    for (int d = 0; d < md.ndims; ++d)
+        md.padded_dims[d] = utils::rnd_up(md.dims[d], blocks[d]);
+
+    dim_t stride = block_size;
+    // if only we could use C++14, the initializer_list would have rbegin()/rend()...
+    for (int d = 0; d < md.ndims; ++d)
+        stride *= md.padded_dims[d] == 0 ? 1 : md.padded_dims[d] / blocks[d];
+
+    for (const auto &d: perm) {
+        if (md.padded_dims[d] == 0) {
+             blk.strides[d] = 1;
+             continue;
+        }
+        stride /= md.padded_dims[d] / blocks[d];
+        blk.strides[d] = stride;
+    }
+
+    assert(stride == block_size);
+
+    return status::success;
+}
+
+status_t memory_desc_wrapper::compute_blocking(memory_desc_t &memory_desc,
+        format_tag_t tag)
+{
+    using namespace format_tag;
+
+    if (memory_desc.ndims == 0) return status::invalid_arguments;
+
+#   define C(tag, ... /* perm, inner_blks, inner_idxs */) \
+    case tag: return fill_blocked(memory_desc, __VA_ARGS__)
+
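+    // Reading the tags (illustrative): the letter sequence lists dimensions in
+    // decreasing-stride order, an uppercase letter marks a dimension that is
+    // additionally blocked, and the trailing <size><letter> groups list the
+    // inner blocks from outer to inner. E.g. ABc4b4a is a 3-D layout whose
+    // innermost data forms 4x4 (b, a) blocks, with a contiguous.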
+    switch (tag) {
+    C(a, {0}, {}, {});
+    C(ab, {0, 1}, {}, {});
+    C(abc, {0, 1, 2}, {}, {});
+    C(abcd, {0, 1, 2, 3}, {}, {});
+    C(abcde, {0, 1, 2, 3, 4}, {}, {});
+    C(abcdef, {0, 1, 2, 3, 4, 5}, {}, {});
+    C(abdec, {0, 1, 3, 4, 2}, {}, {});
+    C(acb, {0, 2, 1}, {}, {});
+    C(acbde, {0, 2, 1, 3, 4}, {}, {});
+    C(acdb, {0, 2, 3, 1}, {}, {});
+    C(acdeb, {0, 2, 3, 4, 1}, {}, {});
+    C(ba, {1, 0}, {}, {});
+    C(bac, {1, 0, 2}, {}, {});
+    C(bacd, {1, 0, 2, 3}, {}, {});
+    C(bcda, {1, 2, 3, 0}, {}, {});
+    C(cba, {2, 1, 0}, {}, {});
+    C(cdba, {2, 3, 1, 0}, {}, {});
+    C(cdeba, {2, 3, 4, 1, 0}, {}, {});
+    C(decab, {3, 4, 2, 0, 1}, {}, {});
+
+    C(Abc4a, {0, 1, 2}, {4}, {0});
+    C(aBc4b, {0, 1, 2}, {4}, {1});
+    C(ABc4b16a4b, {0, 1, 2}, {4, 16, 4}, {1, 0, 1});
+    C(ABc4b4a, {0, 1, 2}, {4, 4}, {1, 0});
+    C(Abcd4a, {0, 1, 2, 3}, {4}, {0});
+    C(aBcd4b, {0, 1, 2, 3}, {4}, {1});
+    C(ABcd4b4a, {0, 1, 2, 3}, {4, 4}, {1, 0});
+    C(aBCd4c16b4c, {0, 1, 2, 3}, {4, 16, 4}, {2, 1, 2});
+    C(aBCd4c4b, {0, 1, 2, 3}, {4, 4}, {2, 1});
+    C(Abcde4a, {0, 1, 2, 3, 4}, {4}, {0});
+    C(aBcde4b, {0, 1, 2, 3, 4}, {4}, {1});
+    C(ABcde4b4a, {0, 1, 2, 3, 4}, {4, 4}, {1, 0});
+    C(aBCde4c4b, {0, 1, 2, 3, 4}, {4, 4}, {2, 1});
+    C(aBcdef4b, {0, 1, 2, 3, 4, 5}, {4}, {1});
+    C(aBCdef4c4b, {0, 1, 2, 3, 4, 5}, {4, 4}, {2, 1});
+    C(aBdc4b, {0, 1, 3, 2}, {4}, {1});
+    C(aBdec4b, {0, 1, 3, 4, 2}, {4}, {1});
+    C(aBdefc4b, {0, 1, 3, 4, 5, 2}, {4}, {1});
+    C(Acb4a, {0, 2, 1}, {4}, {0});
+    C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
+    C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});
+
+    C(Abc16a, {0, 1, 2}, {16}, {0});
+    C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
+    C(aBc16b, {0, 1, 2}, {16}, {1});
+    C(ABc16b16a, {0, 1, 2}, {16, 16}, {1, 0});
+    C(ABc8a16b2a, {0, 1, 2}, {8, 16, 2}, {0, 1, 0});
+    C(ABc8a8b, {0, 1, 2}, {8, 8}, {0, 1});
+    C(aBc8b, {0, 1, 2}, {8}, {1});
+    C(ABc8b16a2b, {0, 1, 2}, {8, 16, 2}, {1, 0, 1});
+    C(ABc8b8a, {0, 1, 2}, {8, 8}, {1, 0});
+    C(Abcd16a, {0, 1, 2, 3}, {16}, {0});
+    C(ABcd16a16b, {0, 1, 2, 3}, {16, 16}, {0, 1});
+    C(aBcd16b, {0, 1, 2, 3}, {16}, {1});
+    C(ABcd16b16a, {0, 1, 2, 3}, {16, 16}, {1, 0});
+    C(aBCd16b16c, {0, 1, 2, 3}, {16, 16}, {1, 2});
+    C(aBCd16c16b, {0, 1, 2, 3}, {16, 16}, {2, 1});
+    C(ABcd4b16a4b, {0, 1, 2, 3}, {4, 16, 4}, {1, 0, 1});
+    C(ABcd8a16b2a, {0, 1, 2, 3}, {8, 16, 2}, {0, 1, 0});
+    C(ABcd8a8b, {0, 1, 2, 3}, {8, 8}, {0, 1});
+    C(aBcd8b, {0, 1, 2, 3}, {8}, {1});
+    C(ABcd8b16a2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 0, 1});
+    C(aBCd8b16c2b, {0, 1, 2, 3}, {8, 16, 2}, {1, 2, 1});
+    C(ABcd8b8a, {0, 1, 2, 3}, {8, 8}, {1, 0});
+    C(aBCd8b8c, {0, 1, 2, 3}, {8, 8}, {1, 2});
+    C(aBCd8c16b2c, {0, 1, 2, 3}, {8, 16, 2}, {2, 1, 2});
+    C(aBCd8c8b, {0, 1, 2, 3}, {8, 8}, {2, 1});
+    C(Abcde16a, {0, 1, 2, 3, 4}, {16}, {0});
+    C(ABcde16a16b, {0, 1, 2, 3, 4}, {16, 16}, {0, 1});
+    C(aBcde16b, {0, 1, 2, 3, 4}, {16}, {1});
+    C(ABcde16b16a, {0, 1, 2, 3, 4}, {16, 16}, {1, 0});
+    C(aBCde16b16c, {0, 1, 2, 3, 4}, {16, 16}, {1, 2});
+    C(aBCde16c16b, {0, 1, 2, 3, 4}, {16, 16}, {2, 1});
+    C(aBCde2c8b4c, {0, 1, 2, 3, 4}, {2, 8, 4}, {2, 1, 2});
+    C(aBCde4b4c, {0, 1, 2, 3, 4}, {4, 4}, {1, 2});
+    C(aBCde4c16b4c, {0, 1, 2, 3, 4}, {4, 16, 4}, {2, 1, 2});
+    C(Abcde8a, {0, 1, 2, 3, 4}, {8}, {0});
+    C(ABcde8a8b, {0, 1, 2, 3, 4}, {8, 8}, {0, 1});
+    C(aBcde8b, {0, 1, 2, 3, 4}, {8}, {1});
+    C(ABcde8b16a2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 0, 1});
+    C(aBCde8b16c2b, {0, 1, 2, 3, 4}, {8, 16, 2}, {1, 2, 1});
+    C(ABcde8b8a, {0, 1, 2, 3, 4}, {8, 8}, {1, 0});
+    C(aBCde8b8c, {0, 1, 2, 3, 4}, {8, 8}, {1, 2});
+    C(aBCde8c16b2c, {0, 1, 2, 3, 4}, {8, 16, 2}, {2, 1, 2});
+    C(aBCde8c8b, {0, 1, 2, 3, 4}, {8, 8}, {2, 1});
+    C(aBcdef16b, {0, 1, 2, 3, 4, 5}, {16}, {1});
+    C(aBCdef16b16c, {0, 1, 2, 3, 4, 5}, {16, 16}, {1, 2});
+    C(aBCdef16c16b, {0, 1, 2, 3, 4, 5}, {16, 16}, {2, 1});
+    C(aBCdef8b8c, {0, 1, 2, 3, 4, 5}, {8, 8}, {1, 2});
+    C(aBCdef8c16b2c, {0, 1, 2, 3, 4, 5}, {8, 16, 2}, {2, 1, 2});
+    C(aBCdef8c8b, {0, 1, 2, 3, 4, 5}, {8, 8}, {2, 1});
+    C(aBdc16b, {0, 1, 3, 2}, {16}, {1});
+    C(aBdc8b, {0, 1, 3, 2}, {8}, {1});
+    C(aBdec16b, {0, 1, 3, 4, 2}, {16}, {1});
+    C(aBdec8b, {0, 1, 3, 4, 2}, {8}, {1});
+    C(aBdefc16b, {0, 1, 3, 4, 5, 2}, {16}, {1});
+    C(aBdefc8b, {0, 1, 3, 4, 5, 2}, {8}, {1});
+    C(Acb16a, {0, 2, 1}, {16}, {0});
+    C(Acb8a, {0, 2, 1}, {8}, {0});
+    C(aCBd16b16c, {0, 2, 1, 3}, {16, 16}, {1, 2});
+    C(aCBde16b16c, {0, 2, 1, 3, 4}, {16, 16}, {1, 2});
+    C(Acdb16a, {0, 2, 3, 1}, {16}, {0});
+    C(Acdb8a, {0, 2, 3, 1}, {8}, {0});
+    C(Acdeb16a, {0, 2, 3, 4, 1}, {16}, {0});
+    C(Acdeb8a, {0, 2, 3, 4, 1}, {8}, {0});
+    C(BAc16a16b, {1, 0, 2}, {16, 16}, {0, 1});
+    C(BAcd16a16b, {1, 0, 2, 3}, {16, 16}, {0, 1});
+    default: break;
+    }
+
+#undef C
+
+    return status::invalid_arguments;
+}
+
+}
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 400 - 0
thirdparty/oidn/mkl-dnn/src/common/memory_desc_wrapper.hpp

@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_DESC_WRAPPER_HPP
+#define MEMORY_DESC_WRAPPER_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/** thin wrapper class over \struct memory_desc_t that allows easy
+ * manipulation of the underlying C structure, which is taken by reference */
+struct memory_desc_wrapper: public c_compatible {
+    const memory_desc_t *md_;
+
+    /** constructor which takes a reference to a constant underlying C memory
+     * descriptor \param md */
+    memory_desc_wrapper(const memory_desc_t *md): md_(md) {}
+    memory_desc_wrapper(const memory_desc_t &md): memory_desc_wrapper(&md) {}
+
+    /* implementing attributes */
+    int ndims() const { return md_->ndims; }
+    const dims_t &dims() const { return md_->dims; }
+    data_type_t data_type() const { return md_->data_type; }
+
+    const dims_t &padded_dims() const { return md_->padded_dims; }
+    const dims_t &padded_offsets() const { return md_->padded_offsets; }
+    dim_t offset0() const { return md_->offset0; }
+
+    format_kind_t format_kind() const { return md_->format_kind; }
+
+    bool is_blocking_desc() const
+    { return format_kind() == format_kind::blocked; }
+    bool is_wino_desc() const
+    { return format_kind() == format_kind::wino; }
+    bool is_rnn_packed_desc() const
+    { return format_kind() == format_kind::rnn_packed; }
+
+    const blocking_desc_t &blocking_desc() const {
+        assert(is_blocking_desc());
+        return md_->format_desc.blocking;
+    }
+    const wino_desc_t &wino_desc() const {
+        assert(is_wino_desc());
+        return md_->format_desc.wino_desc;
+    }
+    const rnn_packed_desc_t &rnn_packed_desc() const {
+        assert(is_rnn_packed_desc());
+        return md_->format_desc.rnn_packed_desc;
+    }
+
+    const memory_extra_desc_t &extra() const { return md_->extra; }
+
+    /* some useful functions */
+
+    /** returns the number of elements including padding if \param with_padding
+     * is true, and the number of data elements otherwise */
+    dim_t nelems(bool with_padding = false) const {
+        if (is_zero()) return 0;
+        return utils::array_product(
+                with_padding ? padded_dims() : dims(), ndims());
+    }
+
+    /** returns true if memory descriptor is zero */
+    bool is_zero() const { return ndims() == 0; }
+
+    /** returns true if memory descriptor contains zero as one of its dim */
+    bool has_zero_dim() const { return nelems() == 0; }
+
+    /** return the size of data type (a shortcut) */
+    size_t data_type_size() const
+    { return types::data_type_size(data_type()); }
+
+    /** return the size of data type of additional buffer */
+    size_t additional_buffer_data_size() const {
+        if (extra().flags & memory_extra_flags::compensation_conv_s8s8)
+            return sizeof(int32_t);
+        return 0;
+    }
+
+    /** return true if memory format has additional buffer */
+    bool is_additional_buffer() const {
+        return (extra().flags & memory_extra_flags::compensation_conv_s8s8);
+    }
+
+    /** returns the size of additional buffer */
+    size_t additional_buffer_size() const {
+        if (extra().flags & memory_extra_flags::compensation_conv_s8s8) {
+            int cmask = extra().compensation_mask;
+            assert(cmask == 1 || cmask == 3);
+            dim_t prod = 1;
+            for (int d = 0; d < ndims(); ++d)
+                if (cmask & (1<<d)) prod *= padded_dims()[d];
+            return prod * additional_buffer_data_size();
+        }
+
+        return 0;
+    }
+
+    /** returns the size required to store described memory
+     * note: if offset0 != 0 returns 0 (need to specify the behavior) */
+    size_t size() const {
+        if (is_zero() || has_zero_dim() || format_kind() == format_kind::any)
+            return 0;
+
+        if (format_kind() == format_kind::wino) {
+            return wino_desc().size;
+        } else if (format_kind() == format_kind::rnn_packed) {
+            return rnn_packed_desc().size;
+        } else {
+            if (offset0() != 0) return 0;
+
+            dims_t blocks = {0};
+            compute_blocks(blocks);
+
+            const auto &bd = blocking_desc();
+
+            size_t max_size = 0;
+            for (int d = 0; d < ndims(); ++d)
+                max_size = nstl::max<size_t>(max_size,
+                        padded_dims()[d] / blocks[d] * bd.strides[d]);
+
+            if (max_size == 1 && bd.inner_nblks != 0) {
+                max_size = utils::array_product(bd.inner_blks, bd.inner_nblks);
+            }
+
+            return max_size * data_type_size() + additional_buffer_size();
+        }
+    }
+
+    /** returns true if data is dense in memory */
+    bool is_dense(bool with_padding = false) const {
+        if (utils::one_of(format_kind(), format_kind::undef, format_kind::any))
+            return false;
+        return nelems(with_padding) * data_type_size() == size();
+    }
+
+    /** returns true if memory desc is fully defined */
+    bool is_defined() const { return format_kind() != format_kind::any; }
+
+    /** returns true if the only (potentially) padded dim is \param dim */
+    bool only_padded_dim(int dim) const {
+        for (int d = 0; d < ndims(); ++d)
+            if (d != dim && dims()[d] != padded_dims()[d])
+                return false;
+        return true;
+    }
+
+    /** returns true if memory desc has blocked layout and block dims are 1s */
+    bool is_plain() const {
+        if (!is_blocking_desc()) return false;
+        return blocking_desc().inner_nblks == 0;
+    }
+
+    /** returns overall block sizes */
+    void compute_blocks(dims_t blocks) const {
+        if (!is_blocking_desc()) {
+            utils::array_set(blocks, 0, ndims());
+            return;
+        }
+
+        utils::array_set(blocks, 1, ndims());
+
+        const auto &bd = blocking_desc();
+        for (int iblk = 0; iblk < bd.inner_nblks; ++iblk)
+            blocks[bd.inner_idxs[iblk]] *= bd.inner_blks[iblk];
+    }
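+    // e.g. for aBcd8b (a.k.a. nChw8c) this yields blocks = {1, 8, 1, 1}.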
+
+    /* comparison section */
+
+    bool operator==(const memory_desc_wrapper &rhs) const
+    { return *this->md_ == *rhs.md_; }
+    bool operator!=(const memory_desc_wrapper &rhs) const
+    { return !operator==(rhs); }
+    bool operator==(const memory_desc_t &rhs) const
+    { return operator==(memory_desc_wrapper(rhs)); }
+    bool operator!=(const memory_desc_t &rhs) const
+    { return !operator==(rhs); }
+
+    /** returns true if the data (w/o padding if with_padding == false and w/
+     * padding otherwise) has the same physical structure, i.e. dimensions,
+     * strides, and blocked structure. Depending on the with_data_type flag,
+     * data_type is or is not taken into account. dim_start allows checking
+     * similarity for the logical part of the data [dim_start .. ndims()].
+     * CAUTION: format kinds any and undef are not similar to anything, hence
+     * the following statement might be true: lhs == rhs && !lhs.similar_to(rhs) */
+    /* TODO: revise */
+    bool similar_to(const memory_desc_wrapper &rhs,
+            bool with_padding = true, bool with_data_type = true,
+            int dim_start = 0) const;
+
+    /** returns true if one memory can be reordered to another */
+    bool consistent_with(const memory_desc_wrapper &rhs) const;
+
+    /** returns true if the memory desc corresponds to the given format tag and
+     * strides.
+     * @sa memory_desc_matches_tag */
+    bool matches_tag(format_tag_t tag, const dims_t strides = nullptr) const {
+        return memory_desc_matches_tag(*md_, tag, strides);
+    }
+
+    /** returns matching tag (or undef if match is not found)
+     * XXX: This is a workaround that eventually should go away! */
+    template <typename... Tags>
+    format_tag_t matches_one_of_tag(Tags ...tags) const {
+        for (const auto tag: {tags...}) {
+            if (memory_desc_matches_tag(*md_, tag))
+                return tag;
+        }
+        return format_tag::undef;
+    }
+
+    /* offset section */
+
+    /** returns physical offset by logical one. logical offset is represented by
+     * an array \param pos. if \param is_pos_padded is true \param pos
+     * represents the position in already padded area */
+    dim_t off_v(const dims_t pos, bool is_pos_padded = false) const {
+        assert(is_blocking_desc());
+        const blocking_desc_t &blk = blocking_desc();
+
+        dims_t pos_copy = {0};
+        for (int d = 0; d < ndims(); ++d)
+            pos_copy[d] = pos[d] + (is_pos_padded ? 0 : padded_offsets()[d]);
+
+        dim_t phys_offset = offset0();
+
+        if (blk.inner_nblks > 0) {
+            dim_t blk_stride = 1;
+            for (int iblk = blk.inner_nblks - 1; iblk >= 0; --iblk) {
+                const int d = blk.inner_idxs[iblk];
+                const dim_t p = pos_copy[d] % blk.inner_blks[iblk];
+
+                phys_offset += p * blk_stride;
+
+                pos_copy[d] /= blk.inner_blks[iblk];
+
+                blk_stride *= blk.inner_blks[iblk];
+            }
+        }
+
+        for (int d = 0; d < ndims(); ++d) {
+            const dim_t p = pos_copy[d];
+            phys_offset += p * blk.strides[d];
+        }
+
+        return phys_offset;
+    }
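+    // For a plain layout (inner_nblks == 0, zero padded offsets) this reduces
+    // to offset0() + sum_d pos[d] * strides[d]; e.g. plain nchw gives
+    // n*C*H*W + c*H*W + h*W + w.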
+
+    /** returns physical offset by logical one. logical offset is represented by
+     * a scalar \param l_offset. if \param is_pos_padded is true, \param
+     * l_offset represents logical offset in already padded area */
+    dim_t off_l(dim_t l_offset, bool is_pos_padded = false) const {
+        assert(is_blocking_desc());
+        dims_t pos;
+        for (int rd = 0; rd < ndims(); ++rd) {
+            const int d = ndims() - 1 - rd;
+            const dim_t cur_dim = is_pos_padded ? padded_dims()[d] : dims()[d];
+            pos[d] = l_offset % cur_dim;
+            l_offset /= cur_dim;
+        }
+        return off_v(pos, is_pos_padded);
+    }
+
+    /** returns physical offset by logical one. logical offset is represented by
+     * a tuple of indices (\param xn, ..., \param x1, \param x0) */
+    template<typename... Args>
+    dim_t off(Args... args) const {
+        assert(sizeof...(args) == ndims());
+        dims_t pos = { args... };
+        return off_v(pos, false);
+    }
+
+    /** returns physical offset by logical one. logical offset is represented by
+     * a tuple of indices (\param xn, ..., \param x1, \param x0) in already
+     * padded area */
+    template<typename... Args>
+    dim_t off_padding(Args... args) const {
+        assert(sizeof...(args) == ndims());
+        dims_t pos = { args... };
+        return off_v(pos, true);
+    }
+
+    /** returns physical offset by logical one. Logical offset is represented by
+     * a tuple of block indices (\param bn, ..., \param b1, \param b0). It is a
+     * user responsibility to adjust the result to get offset within blocks */
+    template<typename ...Args>
+    dim_t blk_off(Args... args) const {
+        return _blk_off<sizeof...(args), Args...>(args...);
+    }
+
+    template<bool skip_first, typename T, typename ...Args>
+    dim_t blk_off(T xn, Args... args) const {
+        return skip_first
+            ? blk_off<Args...>(args...)
+            : blk_off<T, Args...>(xn, args...);
+    }
+
+    /* static functions section */
+    /* TODO: replace with non-static, once md_ becomes non-const ref */
+
+    static status_t compute_blocking(memory_desc_t &memory_desc,
+            format_tag_t tag);
+
+private:
+    /* TODO: put logical_offset in utils */
+    template<typename T>
+    dim_t logical_offset(T x0) const { return x0; }
+
+    template<typename T, typename... Args>
+    dim_t logical_offset(T xn, Args... args) const {
+        const size_t n_args = sizeof...(args);
+        return xn * utils::array_product<n_args>(
+                &dims()[ndims() - n_args]) + logical_offset(args...);
+    }
+
+    template<int ORIG_LEN, typename ...Void>
+    dim_t _blk_off() const { return offset0(); }
+
+    template<int ORIG_LEN, typename T, typename ...Args>
+    dim_t _blk_off(T xc, Args ...args) const {
+        assert(is_blocking_desc());
+        constexpr int dc = ORIG_LEN - sizeof...(args) - 1;
+        return xc * blocking_desc().strides[dc]
+            + _blk_off<ORIG_LEN, Args...>(args...);
+    }
+};
+
+inline bool memory_desc_wrapper::similar_to(const memory_desc_wrapper &rhs,
+        bool with_padding, bool with_data_type, int dim_start) const {
+    using namespace utils;
+
+    if (one_of(format_kind(), format_kind::undef, format_kind::any))
+        return false;
+    if (is_wino_desc() || is_rnn_packed_desc())
+        return false;
+
+    const int ds = dim_start;
+    const auto &blk = blocking_desc();
+    const auto &r_blk = rhs.blocking_desc();
+
+    return ndims() == rhs.ndims()
+        && dim_start <= ndims() /* guard */
+        && format_kind() == rhs.format_kind()
+        && IMPLICATION(with_data_type, data_type() == rhs.data_type())
+        && array_cmp(dims() + ds, rhs.dims() + ds, ndims() - ds)
+        && array_cmp(blk.strides + ds, r_blk.strides + ds, ndims() - ds)
+        && blk.inner_nblks == r_blk.inner_nblks
+        && array_cmp(blk.inner_blks, r_blk.inner_blks, blk.inner_nblks)
+        && array_cmp(blk.inner_idxs, r_blk.inner_idxs, blk.inner_nblks)
+        && IMPLICATION(with_padding, true
+                && array_cmp(padded_dims() + ds, rhs.padded_dims() + ds,
+                    ndims() - ds)
+                && array_cmp(padded_offsets() + ds, rhs.padded_offsets() + ds,
+                    ndims() - ds));
+}
+
+inline bool memory_desc_wrapper::consistent_with(
+        const memory_desc_wrapper &rhs) const {
+    if (ndims() == rhs.ndims()) {
+        for (int d = 0; d < ndims(); ++d) {
+            if (dims()[d] != rhs.dims()[d]) return false;
+        }
+        return true;
+    } else {
+        /* TODO: revise.
+         * is the following possible?
+         * [1, a, b] <--reorder--> [a, b]
+         * [a, 1, b] <--reorder--> [a, b]
+         * not, at least for now */
+        return false;
+    }
+}
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 295 - 0
thirdparty/oidn/mkl-dnn/src/common/memory_tracking.hpp

@@ -0,0 +1,295 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MEMORY_TRACKING_HPP
+#define MEMORY_TRACKING_HPP
+
+#include <assert.h>
+#include <unordered_map>
+
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+namespace memory_tracking {
+
+/* Memory tracking capabilities
+ *
+ * The main purpose of this header file is to provide a uniform way to register
+ * the memory required for a scratchpad at primitive descriptor creation time
+ * and then easily access it given only the base address of the scratchpad.
+ *
+ * Primitives might contain multiple disjoint parts that require temporary
+ * buffers (known as a scratchpad) during their execution. A primitive
+ * descriptor should summarize all these needs into a single number -- the
+ * buffer size that would be requested from a user. At execution time, the
+ * corresponding primitive will receive a base pointer to the scratchpad. It
+ * then needs to provide each part of the algorithm with its piece of memory.
+ * The three main challenges here are:
+ * 1. Tracking the correct offset (from the base scratchpad address) for each
+ *    piece.
+ * 2. The algorithm might require different memory pieces to be aligned, so the
+ *    scratchpad size is no longer just the sum of the sizes of the
+ *    corresponding subparts.
+ * 3. While a primitive is responsible for its scratchpad, the implementation
+ *    might use other basic blocks (e.g. cpu_reducer) that also require
+ *    scratchpad memory. So there should be a simple way of passing this
+ *    information back and forth between the main algorithm (a primitive) and
+ *    auxiliary pieces that live completely separately from it (e.g. the
+ *    reducer).
+ *
+ * To address these challenges this header file provides 3 structures:
+ * 1. registry_t  -- the class that stores the information about requested
+ *                   memory. The information includes required size and desired
+ *                   alignment for each piece. This class is also responsible
+ *                   for computing the right offset to a given piece using the
+ *                   base pointer.
+ *                   This class is basically a ledger with all entries.
+ *                   Lives in primitive descriptors.
+ *
+ * 2. registrar_t -- the interface to a registry_t to book memory. Used at
+ *                   primitive descriptor creation time only. Contains a
+ *                   reference to the corresponding *mutable* registry.
+ *                   Always modifiable.
+ *                   Allows chaining (using prefixes).
+ *
+ * 3. grantor_t   -- the interface to a registry_t to access memory. Used at
+ *                   primitive execution time only. Contains a reference to
+ *                   the corresponding *constant* registry and base pointer.
+ *                   Always constant.
+ *                   Allows chaining (using prefixes).
+ *
+ * Both registrar_t and grantor_t allow chaining with an extra prefix provided.
+ * The feature is useful when a primitive offloads a part of its computations
+ * to some other primitives which require their own scratchpad space
+ * (e.g. a reducer). Prefixes are used to avoid key collisions in cases where
+ * multiple sub-primitives (e.g. multiple reducers) are used.
+ *
+ * A short example below demonstrates how to use the aforementioned classes. In
+ * it the main primitive is a convolution that uses the scratchpad for keeping
+ * a padded bias. It also needs a reducer, which needs its own space as well.
+ *
+ *  ``` c++
+ *  struct reducer_t {
+ *      static void init(registrar_t &scratchpad) {
+ *          // preserve space for the reduction (one page aligned)
+ *          scratchpad.book(key_reducer_space, sizeof(float) * 980 * 1024, 4096);
+ *      }
+ *
+ *      void exec(const grantor_t &scratchpad) {
+ *          // get the pointer to preserved space. scratchpad came from
+ *          // upper primitive (convolution in this example)
+ *          auto space = scratchpad.get<float>(key_reducer_space);
+ *
+ *          space[:] += ...;
+ *      }
+ *  };
+ *
+ *  struct conv_t {
+ *      struct pd_t {
+ *          void init() {
+ *              registrar_t scratchpad(scratchpad_registry_);
+ *
+ *              // preserve a space for padded bias (using default alignment)
+ *              scratchpad.book(key_conv_padded_bias, 128);
+ *
+ *              // create a proxy registrar for the reducer. All entries made
+ *              // by the reducer would live in the convolution's registry, but
+ *              // would have their own `prefix`, so there is no interference
+ *              // with the convolution's buffers.
+ *              registrar_t reducer_scratchpad(scratchpad, prefix_reducer);
+ *
+ *              reducer_t::init(reducer_scratchpad);
+ *          }
+ *
+ *          registry_t scratchpad_registry_;
+ *      }
+ *
+ *      void exec() {
+ *          // get the base pointer to a scratchpad memory from a user
+ *          void *scratchpad_ptr = this->input(MKLDNN_MEM_SCRATCHPAD);
+ *
+ *          // create a grantor to the scratchpad (and provide the base
+ *          // pointer).
+ *          grantor_t scratchpad(pd()->scratchpad_registry_, scratchpad_ptr);
+ *
+ *          // access the padded_bias (need only key name and the grantor)
+ *          auto padded_bias = scratchpad.get<float>(key_conv_padded_bias);
+ *
+ *          // to give the `right` grantor to the reducer we need to add the
+ *          // corresponding prefix, so that the reducer is able to access
+ *          // its keys. The call is very similar to the one in pd_t::init,
+ *          // the only difference being the types: grantor_t vs registrar_t.
+ *          grantor_t reducer_scratchpad(scratchpad, prefix_reducer);
+ *          reducer->exec(reducer_scratchpad);
+ *      }
+ *  };
+ *  ```
+ */
+
+
+/* namespace with common keys and prefixes */
+namespace names {
+enum {
+    key_none = 0,
+    key_bnorm_tmp_mean,
+    key_bnorm_tmp_var,
+    key_bnorm_tmp_diff_ss,
+    key_bnorm_tmp_stats,
+    key_bnorm_reduction,
+    key_concat_iptrs,
+    key_concat_istrides,
+    key_concat_nelems,
+    key_concat_optrs,
+    key_conv_adjusted_scales,
+    key_conv_bia_reduction,
+    key_conv_gemm_col,
+    key_conv_gemm_imtr,
+    key_conv_int_dat_in_acc_dt,
+    key_conv_padded_bias,
+    key_conv_rtus_space,
+    key_conv_tr_diff_dst,
+    key_conv_tr_diff_dst_bctx,
+    key_conv_tr_src,
+    key_conv_tr_src_bctx,
+    key_conv_wei_reduction,
+    key_conv_wei_bia_reduction,
+    key_conv_wei_bia_reduction_bctx,
+    key_iprod_int_dat_in_acc_dt,
+    key_reducer_space,
+    key_reducer_space_bctx,
+    key_reorder_wino_plain,
+    key_reorder_wino_transform_space,
+    key_reorder_rnn_weights_quantization,
+    key_reorder_rnn_weights_reduction,
+    key_rnn_space,
+    key_rnn_ptrs_bia,
+    key_rnn_ptrs_wei_layer,
+    key_rnn_ptrs_wei_iter,
+    key_softmax_reduction,
+    key_wino_U,
+    key_wino_V,
+    key_wino_M,
+    key_barrier,
+};
+
+enum {
+    prefix_none = 0,
+    prefix_reducer_bia,
+    prefix_reducer_wei,
+};
+}
+
+// level 0: 00 00 00 xxx
+// level 1: 00 00 aa xxx
+// level 2: 00 aa bb xxx
+// level 3: aa bb cc xxx
+// max # of levels: 3 + 1 (base_level)
+// here:
+//      xxx        : [1 ..    MAX_KEY) : key
+//      aa, bb, cc : [1 .. MAX_PREFIX) : prefixes for levels 1, 2, and 3
+
+using key_t = uint32_t;
+enum { MAX_KEY = (1u << 10), MAX_PREFIX = (1u << 7), };
+
+/// generates global key based on a prefix and a local key
+inline key_t make_key(key_t prefix, key_t key) { return prefix + key; }
+
+/// generates global prefix based on the global parent and the local ones
+inline key_t make_prefix(key_t parent_prefix, key_t prefix)
+{ return MAX_PREFIX * parent_prefix + MAX_KEY * prefix; }
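+// Worked example: with MAX_KEY = 2^10 and MAX_PREFIX = 2^7, a first-level
+// prefix occupies bits [10..17) and each extra nesting level shifts the parent
+// up by 7 more bits, so make_key(make_prefix(0, prefix_reducer_bia),
+// key_reducer_space) == (prefix_reducer_bia << 10) + key_reducer_space.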
+
+struct registrar_t;
+struct grantor_t;
+
+struct registry_t {
+    void book(const key_t &key, size_t size, size_t alignment) {
+        if (size == 0) return;
+        assert(offset_map_.count(key) == 0);
+
+        size = utils::rnd_up(size, minimal_alignment);
+        alignment = nstl::max<size_t>(alignment, minimal_alignment);
+        offset_map_[key] = entry_t{size_, size, alignment};
+
+        size_ += size + alignment - minimal_alignment;
+    }
+
+    void *get(const key_t &key, void *base_ptr) const {
+        if (base_ptr == nullptr) { assert(size() == 0); return nullptr; }
+        if (offset_map_.count(key) != 1) return nullptr;
+
+        const auto &e = offset_map_.at(key);
+        base_ptr = utils::align_ptr<void>(base_ptr, minimal_alignment);
+        char *ptr = (char *)base_ptr + e.offset;
+        return utils::align_ptr<void>(ptr, e.alignment);
+    }
+
+    size_t size() const
+    { return size_ > 0 ? size_ + minimal_alignment - 1 : 0; }
+
+    registrar_t registrar();
+    grantor_t grantor(void *base_ptr) const;
+
+protected:
+    enum { minimal_alignment = 64 };
+    struct entry_t { size_t offset, size, alignment; };
+
+    std::unordered_map<key_t, entry_t> offset_map_;
+    size_t size_ = 0;
+};
+
+struct registrar_t {
+    enum { default_alignment = 64 };
+
+    registrar_t(registry_t &registry): registry_(registry), prefix_(0) {}
+    registrar_t(registrar_t &parent, const key_t &prefix)
+        : registry_(parent.registry_)
+        , prefix_(make_prefix(parent.prefix_, prefix)) {}
+
+    void book(const key_t &key, size_t size,
+            size_t alignment = default_alignment)
+    { registry_.book(make_key(prefix_, key), size, alignment); }
+
+protected:
+    registry_t &registry_;
+    const key_t prefix_;
+};
+
+struct grantor_t {
+    grantor_t(const registry_t &registry, void *base_ptr)
+        : registry_(registry), prefix_(0), base_ptr_(base_ptr) {}
+    grantor_t(const grantor_t &parent, const key_t &prefix)
+        : registry_(parent.registry_)
+        , prefix_(make_prefix(parent.prefix_, prefix))
+        , base_ptr_(parent.base_ptr_) {}
+
+    template <typename T = void> T *get(const key_t &key) const
+    { return (T *)registry_.get(make_key(prefix_, key), base_ptr_); }
+
+protected:
+    const registry_t &registry_;
+    const key_t prefix_;
+    void *base_ptr_;
+};
+
+inline registrar_t registry_t::registrar() { return registrar_t(*this); }
+inline grantor_t registry_t::grantor(void *base_ptr) const
+{ return grantor_t(*this, base_ptr); }
+
+}
+}
+}
+
+#endif

+ 131 - 0
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug.cpp

@@ -0,0 +1,131 @@
+/*******************************************************************************
+* Copyright 2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include <stdio.h>
+#include <cinttypes>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#define DPRINT(...) do { \
+    int l = snprintf(str + written_len, str_len, __VA_ARGS__); \
+    if (l < 0) return l; \
+    if ((size_t)l >= str_len) return -1; \
+    written_len += l; str_len -= l; \
+} while(0)
+
+int mkldnn_md2fmt_str(char *str, size_t str_len,
+        const mkldnn_memory_desc_t *mdesc) {
+    using namespace mkldnn::impl;
+
+    if (str == nullptr || str_len <= 1u)
+        return -1;
+
+    int written_len = 0;
+
+    if (mdesc == nullptr) {
+        DPRINT("%s::%s::",
+                mkldnn_dt2str(data_type::undef),
+                mkldnn_fmt_kind2str(format_kind::undef));
+        return written_len;
+    }
+
+    memory_desc_wrapper md(mdesc);
+
+    DPRINT("%s:", mkldnn_dt2str(md.data_type()));
+
+    bool padded_dims = false, padded_offsets = false;
+    for (int d = 0; d < md.ndims(); ++d) {
+        if (md.dims()[d] != md.padded_dims()[d]) padded_dims = true;
+        if (md.padded_offsets()[d] != 0) padded_offsets = true;
+    }
+    bool offset0 = md.offset0();
+    DPRINT("%s%s%s:",
+            padded_dims ? "p" : "",
+            padded_offsets ? "o" : "",
+            offset0 ? "0" : "");
+
+    DPRINT("%s:", mkldnn_fmt_kind2str(md.format_kind()));
+
+    if (!md.is_blocking_desc()) {
+        /* TODO: extend */
+        DPRINT("%s:", "");
+    } else {
+        const auto &blk = md.blocking_desc();
+
+        dims_t blocks;
+        md.compute_blocks(blocks);
+
+        char dim_chars[MKLDNN_MAX_NDIMS + 1];
+
+        bool plain = true;
+        for (int d = 0; d < md.ndims(); ++d) {
+            dim_chars[d] = (blocks[d] == 1 ? 'a' : 'A') + (char)d;
+            if (blocks[d] != 1) plain = false;
+        }
+
+        dims_t strides;
+        utils::array_copy(strides, blk.strides, md.ndims());
+        utils::simultaneous_sort(strides, dim_chars, md.ndims(),
+                [](dim_t a, dim_t b) { return b - a; });
+
+        dim_chars[md.ndims()] = '\0';
+        DPRINT("%s", dim_chars);
+
+        if (!plain) {
+            for (int iblk = 0; iblk < blk.inner_nblks; ++iblk) {
+                DPRINT("%d%c", (int)blk.inner_blks[iblk],
+                        'a' + (char)blk.inner_idxs[iblk]);
+            }
+        }
+
+        DPRINT("%s", ":");
+    }
+
+    DPRINT("f%lx", (long)md.extra().flags);
+
+    return written_len;
+}
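+// Sample output (illustrative): a plain, unpadded f32 nhwc descriptor with no
+// extra flags prints as "f32::blocked:acdb:f0", i.e. data type, padding
+// markers (empty here), format kind, dims ordered by decreasing stride, and
+// the extra flags.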
+
+int mkldnn_md2dim_str(char *str, size_t str_len,
+        const mkldnn_memory_desc_t *mdesc) {
+    using namespace mkldnn::impl;
+
+    if (str == nullptr || str_len <= 1)
+        return -1;
+
+    int written_len = 0;
+
+    if (mdesc == nullptr || mdesc->ndims == 0) {
+        DPRINT("%s", "");
+        return written_len;
+    }
+
+    memory_desc_wrapper md(mdesc);
+
+    for (int d = 0; d < md.ndims() - 1; ++d)
+        DPRINT("%" PRId64 "x", md.dims()[d]);
+    DPRINT("%" PRId64, md.dims()[md.ndims() - 1]);
+
+    return written_len;
+}
+
+#undef  DPRINT

+ 365 - 0
thirdparty/oidn/mkl-dnn/src/common/mkldnn_debug_autogenerated.cpp

@@ -0,0 +1,365 @@
+/*******************************************************************************
+* Copyright 2018-2019 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+/* DO NOT EDIT, AUTO-GENERATED */
+
+#include <assert.h>
+
+#include "mkldnn_debug.h"
+#include "mkldnn_types.h"
+
+const char *mkldnn_status2str(mkldnn_status_t v) {
+    if (v == mkldnn_success) return "success";
+    if (v == mkldnn_out_of_memory) return "out_of_memory";
+    if (v == mkldnn_try_again) return "try_again";
+    if (v == mkldnn_invalid_arguments) return "invalid_arguments";
+    if (v == mkldnn_not_ready) return "not_ready";
+    if (v == mkldnn_unimplemented) return "unimplemented";
+    if (v == mkldnn_iterator_ends) return "iterator_ends";
+    if (v == mkldnn_runtime_error) return "runtime_error";
+    if (v == mkldnn_not_required) return "not_required";
+    assert(!"unknown status");
+    return "unknown status";
+}
+
+const char *mkldnn_dt2str(mkldnn_data_type_t v) {
+    if (v == mkldnn_data_type_undef) return "undef";
+    if (v == mkldnn_f32) return "f32";
+    if (v == mkldnn_s32) return "s32";
+    if (v == mkldnn_s8) return "s8";
+    if (v == mkldnn_u8) return "u8";
+    assert(!"unknown dt");
+    return "unknown dt";
+}
+
+const char *mkldnn_fmt_kind2str(mkldnn_format_kind_t v) {
+    if (v == mkldnn_format_kind_undef) return "undef";
+    if (v == mkldnn_format_kind_any) return "any";
+    if (v == mkldnn_blocked) return "blocked";
+    if (v == mkldnn_format_kind_wino) return "wino";
+    if (v == mkldnn_format_kind_rnn_packed) return "rnn_packed";
+    assert(!"unknown fmt_kind");
+    return "unknown fmt_kind";
+}
+
+const char *mkldnn_fmt_tag2str(mkldnn_format_tag_t v) {
+    if (v == mkldnn_format_tag_undef) return "undef";
+    if (v == mkldnn_format_tag_any) return "format_tag_any";
+    if (v == mkldnn_a) return "a";
+    if (v == mkldnn_ab) return "ab";
+    if (v == mkldnn_abc) return "abc";
+    if (v == mkldnn_abcd) return "abcd";
+    if (v == mkldnn_abcde) return "abcde";
+    if (v == mkldnn_abcdef) return "abcdef";
+    if (v == mkldnn_abdec) return "abdec";
+    if (v == mkldnn_acb) return "acb";
+    if (v == mkldnn_acbde) return "acbde";
+    if (v == mkldnn_acdb) return "acdb";
+    if (v == mkldnn_acdeb) return "acdeb";
+    if (v == mkldnn_ba) return "ba";
+    if (v == mkldnn_bac) return "bac";
+    if (v == mkldnn_bacd) return "bacd";
+    if (v == mkldnn_bcda) return "bcda";
+    if (v == mkldnn_cba) return "cba";
+    if (v == mkldnn_cdba) return "cdba";
+    if (v == mkldnn_cdeba) return "cdeba";
+    if (v == mkldnn_decab) return "decab";
+    if (v == mkldnn_Abc16a) return "Abc16a";
+    if (v == mkldnn_ABc16a16b) return "ABc16a16b";
+    if (v == mkldnn_aBc16b) return "aBc16b";
+    if (v == mkldnn_ABc16b16a) return "ABc16b16a";
+    if (v == mkldnn_Abc4a) return "Abc4a";
+    if (v == mkldnn_aBc4b) return "aBc4b";
+    if (v == mkldnn_ABc4b16a4b) return "ABc4b16a4b";
+    if (v == mkldnn_ABc4b4a) return "ABc4b4a";
+    if (v == mkldnn_ABc8a16b2a) return "ABc8a16b2a";
+    if (v == mkldnn_ABc8a8b) return "ABc8a8b";
+    if (v == mkldnn_aBc8b) return "aBc8b";
+    if (v == mkldnn_ABc8b16a2b) return "ABc8b16a2b";
+    if (v == mkldnn_ABc8b8a) return "ABc8b8a";
+    if (v == mkldnn_Abcd16a) return "Abcd16a";
+    if (v == mkldnn_ABcd16a16b) return "ABcd16a16b";
+    if (v == mkldnn_aBcd16b) return "aBcd16b";
+    if (v == mkldnn_ABcd16b16a) return "ABcd16b16a";
+    if (v == mkldnn_aBCd16b16c) return "aBCd16b16c";
+    if (v == mkldnn_aBCd16c16b) return "aBCd16c16b";
+    if (v == mkldnn_Abcd4a) return "Abcd4a";
+    if (v == mkldnn_aBcd4b) return "aBcd4b";
+    if (v == mkldnn_ABcd4b16a4b) return "ABcd4b16a4b";
+    if (v == mkldnn_ABcd4b4a) return "ABcd4b4a";
+    if (v == mkldnn_aBCd4c16b4c) return "aBCd4c16b4c";
+    if (v == mkldnn_aBCd4c4b) return "aBCd4c4b";
+    if (v == mkldnn_ABcd8a16b2a) return "ABcd8a16b2a";
+    if (v == mkldnn_ABcd8a8b) return "ABcd8a8b";
+    if (v == mkldnn_aBcd8b) return "aBcd8b";
+    if (v == mkldnn_ABcd8b16a2b) return "ABcd8b16a2b";
+    if (v == mkldnn_aBCd8b16c2b) return "aBCd8b16c2b";
+    if (v == mkldnn_ABcd8b8a) return "ABcd8b8a";
+    if (v == mkldnn_aBCd8b8c) return "aBCd8b8c";
+    if (v == mkldnn_aBCd8c16b2c) return "aBCd8c16b2c";
+    if (v == mkldnn_aBCd8c8b) return "aBCd8c8b";
+    if (v == mkldnn_Abcde16a) return "Abcde16a";
+    if (v == mkldnn_ABcde16a16b) return "ABcde16a16b";
+    if (v == mkldnn_aBcde16b) return "aBcde16b";
+    if (v == mkldnn_ABcde16b16a) return "ABcde16b16a";
+    if (v == mkldnn_aBCde16b16c) return "aBCde16b16c";
+    if (v == mkldnn_aBCde16c16b) return "aBCde16c16b";
+    if (v == mkldnn_aBCde2c8b4c) return "aBCde2c8b4c";
+    if (v == mkldnn_Abcde4a) return "Abcde4a";
+    if (v == mkldnn_aBcde4b) return "aBcde4b";
+    if (v == mkldnn_ABcde4b4a) return "ABcde4b4a";
+    if (v == mkldnn_aBCde4b4c) return "aBCde4b4c";
+    if (v == mkldnn_aBCde4c16b4c) return "aBCde4c16b4c";
+    if (v == mkldnn_aBCde4c4b) return "aBCde4c4b";
+    if (v == mkldnn_Abcde8a) return "Abcde8a";
+    if (v == mkldnn_ABcde8a8b) return "ABcde8a8b";
+    if (v == mkldnn_ABcde8b16a2b) return "ABcde8b16a2b";
+    if (v == mkldnn_aBCde8b16c2b) return "aBCde8b16c2b";
+    if (v == mkldnn_ABcde8b8a) return "ABcde8b8a";
+    if (v == mkldnn_aBCde8b8c) return "aBCde8b8c";
+    if (v == mkldnn_aBCde8c16b2c) return "aBCde8c16b2c";
+    if (v == mkldnn_aBCde8c8b) return "aBCde8c8b";
+    if (v == mkldnn_aBcdef16b) return "aBcdef16b";
+    if (v == mkldnn_aBCdef16b16c) return "aBCdef16b16c";
+    if (v == mkldnn_aBCdef16c16b) return "aBCdef16c16b";
+    if (v == mkldnn_aBcdef4b) return "aBcdef4b";
+    if (v == mkldnn_aBCdef4c4b) return "aBCdef4c4b";
+    if (v == mkldnn_aBCdef8b8c) return "aBCdef8b8c";
+    if (v == mkldnn_aBCdef8c16b2c) return "aBCdef8c16b2c";
+    if (v == mkldnn_aBCdef8c8b) return "aBCdef8c8b";
+    if (v == mkldnn_aBdc16b) return "aBdc16b";
+    if (v == mkldnn_aBdc4b) return "aBdc4b";
+    if (v == mkldnn_aBdc8b) return "aBdc8b";
+    if (v == mkldnn_aBdec16b) return "aBdec16b";
+    if (v == mkldnn_aBdec4b) return "aBdec4b";
+    if (v == mkldnn_aBdec8b) return "aBdec8b";
+    if (v == mkldnn_aBdefc16b) return "aBdefc16b";
+    if (v == mkldnn_aBdefc4b) return "aBdefc4b";
+    if (v == mkldnn_aBdefc8b) return "aBdefc8b";
+    if (v == mkldnn_Acb16a) return "Acb16a";
+    if (v == mkldnn_Acb4a) return "Acb4a";
+    if (v == mkldnn_Acb8a) return "Acb8a";
+    if (v == mkldnn_aCBd16b16c) return "aCBd16b16c";
+    if (v == mkldnn_aCBde16b16c) return "aCBde16b16c";
+    if (v == mkldnn_Acdb16a) return "Acdb16a";
+    if (v == mkldnn_Acdb4a) return "Acdb4a";
+    if (v == mkldnn_Acdb8a) return "Acdb8a";
+    if (v == mkldnn_Acdeb16a) return "Acdeb16a";
+    if (v == mkldnn_Acdeb4a) return "Acdeb4a";
+    if (v == mkldnn_Acdeb8a) return "Acdeb8a";
+    if (v == mkldnn_BAc16a16b) return "BAc16a16b";
+    if (v == mkldnn_BAcd16a16b) return "BAcd16a16b";
+    if (v == mkldnn_format_tag_last) return "format_tag_last";
+    if (v == mkldnn_x) return "x";
+    if (v == mkldnn_nc) return "nc";
+    if (v == mkldnn_cn) return "cn";
+    if (v == mkldnn_ncw) return "ncw";
+    if (v == mkldnn_nwc) return "nwc";
+    if (v == mkldnn_nchw) return "nchw";
+    if (v == mkldnn_nhwc) return "nhwc";
+    if (v == mkldnn_chwn) return "chwn";
+    if (v == mkldnn_ncdhw) return "ncdhw";
+    if (v == mkldnn_ndhwc) return "ndhwc";
+    if (v == mkldnn_oi) return "oi";
+    if (v == mkldnn_io) return "io";
+    if (v == mkldnn_oiw) return "oiw";
+    if (v == mkldnn_wio) return "wio";
+    if (v == mkldnn_oihw) return "oihw";
+    if (v == mkldnn_hwio) return "hwio";
+    if (v == mkldnn_ihwo) return "ihwo";
+    if (v == mkldnn_iohw) return "iohw";
+    if (v == mkldnn_oidhw) return "oidhw";
+    if (v == mkldnn_dhwio) return "dhwio";
+    if (v == mkldnn_goiw) return "goiw";
+    if (v == mkldnn_goihw) return "goihw";
+    if (v == mkldnn_hwigo) return "hwigo";
+    if (v == mkldnn_giohw) return "giohw";
+    if (v == mkldnn_goidhw) return "goidhw";
+    if (v == mkldnn_tnc) return "tnc";
+    if (v == mkldnn_ntc) return "ntc";
+    if (v == mkldnn_ldsnc) return "ldsnc";
+    if (v == mkldnn_ldigo) return "ldigo";
+    if (v == mkldnn_ldgoi) return "ldgoi";
+    if (v == mkldnn_ldgo) return "ldgo";
+    if (v == mkldnn_nCdhw16c) return "nCdhw16c";
+    if (v == mkldnn_nCdhw4c) return "nCdhw4c";
+    if (v == mkldnn_nCdhw8c) return "nCdhw8c";
+    if (v == mkldnn_nChw16c) return "nChw16c";
+    if (v == mkldnn_nChw4c) return "nChw4c";
+    if (v == mkldnn_nChw8c) return "nChw8c";
+    if (v == mkldnn_nCw16c) return "nCw16c";
+    if (v == mkldnn_nCw4c) return "nCw4c";
+    if (v == mkldnn_nCw8c) return "nCw8c";
+    if (v == mkldnn_IOw16o16i) return "IOw16o16i";
+    if (v == mkldnn_OIw16i16o) return "OIw16i16o";
+    if (v == mkldnn_OIw16o16i) return "OIw16o16i";
+    if (v == mkldnn_Oiw16o) return "Oiw16o";
+    if (v == mkldnn_OIw4i16o4i) return "OIw4i16o4i";
+    if (v == mkldnn_OIw4i4o) return "OIw4i4o";
+    if (v == mkldnn_Oiw4o) return "Oiw4o";
+    if (v == mkldnn_OIw8i16o2i) return "OIw8i16o2i";
+    if (v == mkldnn_OIw8i8o) return "OIw8i8o";
+    if (v == mkldnn_OIw8o16i2o) return "OIw8o16i2o";
+    if (v == mkldnn_OIw8o8i) return "OIw8o8i";
+    if (v == mkldnn_Owi16o) return "Owi16o";
+    if (v == mkldnn_Owi4o) return "Owi4o";
+    if (v == mkldnn_Owi8o) return "Owi8o";
+    if (v == mkldnn_IOhw16o16i) return "IOhw16o16i";
+    if (v == mkldnn_Ohwi16o) return "Ohwi16o";
+    if (v == mkldnn_Ohwi4o) return "Ohwi4o";
+    if (v == mkldnn_Ohwi8o) return "Ohwi8o";
+    if (v == mkldnn_OIhw16i16o) return "OIhw16i16o";
+    if (v == mkldnn_OIhw16o16i) return "OIhw16o16i";
+    if (v == mkldnn_Oihw16o) return "Oihw16o";
+    if (v == mkldnn_OIhw4i16o4i) return "OIhw4i16o4i";
+    if (v == mkldnn_OIhw4i4o) return "OIhw4i4o";
+    if (v == mkldnn_Oihw4o) return "Oihw4o";
+    if (v == mkldnn_OIhw8i16o2i) return "OIhw8i16o2i";
+    if (v == mkldnn_OIhw8i8o) return "OIhw8i8o";
+    if (v == mkldnn_OIhw8o16i2o) return "OIhw8o16i2o";
+    if (v == mkldnn_OIhw8o8i) return "OIhw8o8i";
+    if (v == mkldnn_Odhwi16o) return "Odhwi16o";
+    if (v == mkldnn_Odhwi4o) return "Odhwi4o";
+    if (v == mkldnn_Odhwi8o) return "Odhwi8o";
+    if (v == mkldnn_OIdhw16i16o) return "OIdhw16i16o";
+    if (v == mkldnn_OIdhw16o16i) return "OIdhw16o16i";
+    if (v == mkldnn_Oidhw16o) return "Oidhw16o";
+    if (v == mkldnn_OIdhw4i4o) return "OIdhw4i4o";
+    if (v == mkldnn_Oidhw4o) return "Oidhw4o";
+    if (v == mkldnn_OIdhw8i16o2i) return "OIdhw8i16o2i";
+    if (v == mkldnn_OIdhw8i8o) return "OIdhw8i8o";
+    if (v == mkldnn_OIdhw8o8i) return "OIdhw8o8i";
+    if (v == mkldnn_Goiw16g) return "Goiw16g";
+    if (v == mkldnn_gIOw16o16i) return "gIOw16o16i";
+    if (v == mkldnn_gOIw16i16o) return "gOIw16i16o";
+    if (v == mkldnn_gOIw16o16i) return "gOIw16o16i";
+    if (v == mkldnn_gOiw16o) return "gOiw16o";
+    if (v == mkldnn_gOIw4i16o4i) return "gOIw4i16o4i";
+    if (v == mkldnn_gOIw4i4o) return "gOIw4i4o";
+    if (v == mkldnn_gOiw4o) return "gOiw4o";
+    if (v == mkldnn_gOIw8i16o2i) return "gOIw8i16o2i";
+    if (v == mkldnn_gOIw8i8o) return "gOIw8i8o";
+    if (v == mkldnn_gOIw8o16i2o) return "gOIw8o16i2o";
+    if (v == mkldnn_gOIw8o8i) return "gOIw8o8i";
+    if (v == mkldnn_gOwi16o) return "gOwi16o";
+    if (v == mkldnn_gOwi4o) return "gOwi4o";
+    if (v == mkldnn_gOwi8o) return "gOwi8o";
+    if (v == mkldnn_gIOhw16o16i) return "gIOhw16o16i";
+    if (v == mkldnn_gOhwi16o) return "gOhwi16o";
+    if (v == mkldnn_gOhwi4o) return "gOhwi4o";
+    if (v == mkldnn_gOhwi8o) return "gOhwi8o";
+    if (v == mkldnn_Goihw16g) return "Goihw16g";
+    if (v == mkldnn_gOIhw16i16o) return "gOIhw16i16o";
+    if (v == mkldnn_gOIhw16o16i) return "gOIhw16o16i";
+    if (v == mkldnn_gOihw16o) return "gOihw16o";
+    if (v == mkldnn_gOIhw2i8o4i) return "gOIhw2i8o4i";
+    if (v == mkldnn_gOIhw4i16o4i) return "gOIhw4i16o4i";
+    if (v == mkldnn_gOIhw4i4o) return "gOIhw4i4o";
+    if (v == mkldnn_gOIhw4o4i) return "gOIhw4o4i";
+    if (v == mkldnn_gOihw4o) return "gOihw4o";
+    if (v == mkldnn_Goihw8g) return "Goihw8g";
+    if (v == mkldnn_gOIhw8i16o2i) return "gOIhw8i16o2i";
+    if (v == mkldnn_gOIhw8i8o) return "gOIhw8i8o";
+    if (v == mkldnn_gOIhw8o16i2o) return "gOIhw8o16i2o";
+    if (v == mkldnn_gOIhw8o8i) return "gOIhw8o8i";
+    if (v == mkldnn_gOdhwi16o) return "gOdhwi16o";
+    if (v == mkldnn_gOdhwi4o) return "gOdhwi4o";
+    if (v == mkldnn_gOdhwi8o) return "gOdhwi8o";
+    if (v == mkldnn_gOIdhw16i16o) return "gOIdhw16i16o";
+    if (v == mkldnn_gOIdhw16o16i) return "gOIdhw16o16i";
+    if (v == mkldnn_gOidhw16o) return "gOidhw16o";
+    if (v == mkldnn_gOIdhw4i4o) return "gOIdhw4i4o";
+    if (v == mkldnn_gOidhw4o) return "gOidhw4o";
+    if (v == mkldnn_gOIdhw8i16o2i) return "gOIdhw8i16o2i";
+    if (v == mkldnn_gOIdhw8i8o) return "gOIdhw8i8o";
+    if (v == mkldnn_gOIdhw8o8i) return "gOIdhw8o8i";
+    assert(!"unknown fmt_tag");
+    return "unknown fmt_tag";
+}
+
+const char *mkldnn_prop_kind2str(mkldnn_prop_kind_t v) {
+    if (v == mkldnn_prop_kind_undef) return "undef";
+    if (v == mkldnn_forward_training) return "forward_training";
+    if (v == mkldnn_forward_inference) return "forward_inference";
+    if (v == mkldnn_forward_scoring) return "forward_scoring";
+    if (v == mkldnn_forward) return "forward";
+    if (v == mkldnn_backward) return "backward";
+    if (v == mkldnn_backward_data) return "backward_data";
+    if (v == mkldnn_backward_weights) return "backward_weights";
+    if (v == mkldnn_backward_bias) return "backward_bias";
+    assert(!"unknown prop_kind");
+    return "unknown prop_kind";
+}
+
+const char *mkldnn_prim_kind2str(mkldnn_primitive_kind_t v) {
+    if (v == mkldnn_undefined_primitive) return "undef";
+    if (v == mkldnn_reorder) return "reorder";
+    if (v == mkldnn_shuffle) return "shuffle";
+    if (v == mkldnn_concat) return "concat";
+    if (v == mkldnn_sum) return "sum";
+    if (v == mkldnn_convolution) return "convolution";
+    if (v == mkldnn_deconvolution) return "deconvolution";
+    if (v == mkldnn_eltwise) return "eltwise";
+    if (v == mkldnn_softmax) return "softmax";
+    if (v == mkldnn_pooling) return "pooling";
+    if (v == mkldnn_lrn) return "lrn";
+    if (v == mkldnn_batch_normalization) return "batch_normalization";
+    if (v == mkldnn_inner_product) return "inner_product";
+    if (v == mkldnn_rnn) return "rnn";
+    assert(!"unknown prim_kind");
+    return "unknown prim_kind";
+}
+
+const char *mkldnn_alg_kind2str(mkldnn_alg_kind_t v) {
+    if (v == mkldnn_alg_kind_undef) return "undef";
+    if (v == mkldnn_convolution_direct) return "convolution_direct";
+    if (v == mkldnn_convolution_winograd) return "convolution_winograd";
+    if (v == mkldnn_convolution_auto) return "convolution_auto";
+    if (v == mkldnn_deconvolution_direct) return "deconvolution_direct";
+    if (v == mkldnn_deconvolution_winograd) return "deconvolution_winograd";
+    if (v == mkldnn_eltwise_relu) return "eltwise_relu";
+    if (v == mkldnn_eltwise_tanh) return "eltwise_tanh";
+    if (v == mkldnn_eltwise_elu) return "eltwise_elu";
+    if (v == mkldnn_eltwise_square) return "eltwise_square";
+    if (v == mkldnn_eltwise_abs) return "eltwise_abs";
+    if (v == mkldnn_eltwise_sqrt) return "eltwise_sqrt";
+    if (v == mkldnn_eltwise_linear) return "eltwise_linear";
+    if (v == mkldnn_eltwise_bounded_relu) return "eltwise_bounded_relu";
+    if (v == mkldnn_eltwise_soft_relu) return "eltwise_soft_relu";
+    if (v == mkldnn_eltwise_logistic) return "eltwise_logistic";
+    if (v == mkldnn_pooling_max) return "pooling_max";
+    if (v == mkldnn_pooling_avg_include_padding) return "pooling_avg_include_padding";
+    if (v == mkldnn_pooling_avg_exclude_padding) return "pooling_avg_exclude_padding";
+    if (v == mkldnn_pooling_avg) return "pooling_avg";
+    if (v == mkldnn_lrn_across_channels) return "lrn_across_channels";
+    if (v == mkldnn_lrn_within_channel) return "lrn_within_channel";
+    if (v == mkldnn_vanilla_rnn) return "vanilla_rnn";
+    if (v == mkldnn_vanilla_lstm) return "vanilla_lstm";
+    if (v == mkldnn_vanilla_gru) return "vanilla_gru";
+    if (v == mkldnn_gru_linear_before_reset) return "gru_linear_before_reset";
+    assert(!"unknown alg_kind");
+    return "unknown alg_kind";
+}
+
+const char *mkldnn_rnn_direction2str(mkldnn_rnn_direction_t v) {
+    if (v == mkldnn_unidirectional_left2right) return "unidirectional_left2right";
+    if (v == mkldnn_unidirectional_right2left) return "unidirectional_right2left";
+    if (v == mkldnn_bidirectional_concat) return "bidirectional_concat";
+    if (v == mkldnn_bidirectional_sum) return "bidirectional_sum";
+    if (v == mkldnn_unidirectional) return "unidirectional";
+    assert(!"unknown rnn_direction");
+    return "unknown rnn_direction";
+}
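A quick illustration of how these debug string helpers are used (a minimal sketch; it assumes the declarations live in include/mkldnn_debug.h, as in upstream MKL-DNN, and that the thirdparty include paths are set up):

    #include <cstdio>
    #include "mkldnn_debug.h"   // assumed to declare the *2str helpers
    #include "mkldnn_types.h"   // enum values such as mkldnn_forward_inference

    int main() {
        // Expected output: "forward_inference" and "convolution_direct",
        // matching the if-chains in mkldnn_debug.cpp above.
        std::printf("%s\n", mkldnn_prop_kind2str(mkldnn_forward_inference));
        std::printf("%s\n", mkldnn_alg_kind2str(mkldnn_convolution_direct));
        return 0;
    }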

+ 115 - 0
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread.hpp

@@ -0,0 +1,115 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_HPP
+#define MKLDNN_THREAD_HPP
+
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+#define MKLDNN_THR_SEQ 0
+#define MKLDNN_THR_OMP 1
+#define MKLDNN_THR_TBB 2
+
+/* Ideally the condition below should never trigger (when the library is built
+ * with the regular CMake setup). For third-party projects that build the
+ * library from source on their own, try to guess the right threading... */
+#if !defined(MKLDNN_THR)
+#   define MKLDNN_THR MKLDNN_THR_TBB
+#endif
+
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+#define MKLDNN_THR_SYNC 1
+inline int mkldnn_get_max_threads() { return 1; }
+inline int mkldnn_get_num_threads() { return 1; }
+inline int mkldnn_get_thread_num() { return 0; }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() {}
+
+#define PRAGMA_OMP(...)
+
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+#include <omp.h>
+#define MKLDNN_THR_SYNC 1
+
+inline int mkldnn_get_max_threads() { return omp_get_max_threads(); }
+inline int mkldnn_get_num_threads() { return omp_get_num_threads(); }
+inline int mkldnn_get_thread_num() { return omp_get_thread_num(); }
+inline int mkldnn_in_parallel() { return omp_in_parallel(); }
+inline void mkldnn_thr_barrier() {
+#   pragma omp barrier
+}
+
+#define PRAGMA_OMP(...) PRAGMA_MACRO(CHAIN2(omp, __VA_ARGS__))
+
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+#include "tbb/task_arena.h"
+#include "tbb/parallel_for.h"
+#define MKLDNN_THR_SYNC 0
+
+inline int mkldnn_get_max_threads()
+{ return tbb::this_task_arena::max_concurrency(); }
+inline int mkldnn_get_num_threads() { return mkldnn_get_max_threads(); }
+inline int mkldnn_get_thread_num()
+{ return tbb::this_task_arena::current_thread_index(); }
+inline int mkldnn_in_parallel() { return 0; }
+inline void mkldnn_thr_barrier() { assert(!"no barrier in TBB"); }
+
+#define PRAGMA_OMP(...)
+
+#endif
+
+/* MSVC still supports omp 2.0 only */
+#if defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+#   define collapse(x)
+#   define PRAGMA_OMP_SIMD(...)
+#else
+#   define PRAGMA_OMP_SIMD(...) PRAGMA_MACRO(CHAIN2(omp, simd __VA_ARGS__))
+#endif // defined(_MSC_VER) && !defined(__clang__) && !defined(__INTEL_COMPILER)
+
+namespace mkldnn {
+namespace impl {
+
+inline bool mkldnn_thr_syncable() { return MKLDNN_THR_SYNC == 1; }
+
+template <typename T, typename U>
+inline void balance211(T n, U team, U tid, T &n_start, T &n_end) {
+    T n_min = 1;
+    T &n_my = n_end;
+    if (team <= 1 || n == 0) {
+        n_start = 0;
+        n_my = n;
+    } else if (n_min == 1) {
+        // team = T1 + T2
+        // n = T1*n1 + T2*n2  (n1 - n2 = 1)
+        T n1 = utils::div_up(n, (T)team);
+        T n2 = n1 - 1;
+        T T1 = n - n2 * (T)team;
+        n_my = (T)tid < T1 ? n1 : n2;
+        n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
+    }
+
+    n_end += n_start;
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#include "mkldnn_thread_parallel_nd.hpp"
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
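balance211 splits n work items across a team so that the first few threads get at most one extra item each. A minimal sketch of the resulting ranges (illustrative only; compiled with -DMKLDNN_THR=0 so the sequential backend is used and no OpenMP/TBB headers are needed):

    #include <cstdio>
    #include "mkldnn_thread.hpp"   // brings balance211 into mkldnn::impl

    int main() {
        // 10 items over 4 threads: n1 = div_up(10, 4) = 3, n2 = 2, T1 = 2,
        // so the ranges come out as [0,3) [3,6) [6,8) [8,10).
        for (int tid = 0; tid < 4; ++tid) {
            int start = 0, end = 0;
            mkldnn::impl::balance211(10, 4, tid, start, end);
            std::printf("thread %d -> [%d, %d)\n", tid, start, end);
        }
        return 0;
    }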

+ 277 - 0
thirdparty/oidn/mkl-dnn/src/common/mkldnn_thread_parallel_nd.hpp

@@ -0,0 +1,277 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_THREAD_PARALLEL_ND_HPP
+#define MKLDNN_THREAD_PARALLEL_ND_HPP
+
+/* This header must be included by mkldnn_thread.hpp only */
+
+/* Functions:
+ *  - parallel(nthr, f)              - executes f in parallel using at most
+ *                                     nthr threads. If nthr equals 0,
+ *                                     mkldnn_get_max_threads() threads are
+ *                                     used
+ *  - for_nd(ithr, nthr, dims..., f) - multidimensional for loop for already
+ *                                     created threads
+ *  - parallel_nd(dims..., f)        - creates a parallel section and then
+ *                                     calls for_nd
+ *  - parallel_nd_in_omp(dims..., f) - queries current nthr and ithr and then
+ *                                     calls for_nd (mostly for convenience)
+ */
+
+namespace mkldnn {
+namespace impl {
+
+/* general parallelization */
+template <typename F>
+void parallel(int nthr, F f) {
+    if (nthr == 0) nthr = mkldnn_get_max_threads();
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    assert(nthr == 1);
+    f(0, 1);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    if (nthr == 1) { f(0, 1); return; }
+#   pragma omp parallel num_threads(nthr)
+    f(mkldnn_get_thread_num(), mkldnn_get_num_threads());
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+    if (nthr == 1) { f(0, 1); return; }
+    tbb::parallel_for(0, nthr, [&](int ithr) { f(ithr, nthr); }, tbb::static_partitioner());
+#endif
+}
+
+/* for_nd section */
+
+template <typename T0, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, F f) {
+    T0 start{0}, end{0};
+    balance211(D0, nthr, ithr, start, end);
+    for (T0 d0 = start; d0 < end; ++d0) f(d0);
+}
+
+template <typename T0, typename T1, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1, F f) {
+    const size_t work_amount = (size_t)D0 * D1;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1);
+        utils::nd_iterator_step(d0, D0, d1, D1);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+         typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, const T4 &D4, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3, d4);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+    }
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+         typename T5, typename F>
+void for_nd(const int ithr, const int nthr, const T0 &D0, const T1 &D1,
+        const T2 &D2, const T3 &D3, const T4 &D4, const T5 &D5, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+    if (work_amount == 0) return;
+    size_t start{0}, end{0};
+    balance211(work_amount, nthr, ithr, start, end);
+
+    T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+    utils::nd_iterator_init(start, d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+            d5, D5);
+    for (size_t iwork = start; iwork < end; ++iwork) {
+        f(d0, d1, d2, d3, d4, d5);
+        utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+    }
+}
+
+// Skip the trailing lambda in the parameter pack when computing the work amount.
+template <typename T>
+constexpr size_t get_work_amount(const T &v) { return 1; }
+template <typename T, typename ...Args>
+constexpr size_t get_work_amount(const T &v, Args &&...args)
+{ return (size_t)v * get_work_amount(utils::forward<Args>(args)...); }
+
+/* parallel_nd and parallel_nd_in_omp section */
+
+#if MKLDNN_THR != MKLDNN_THR_TBB
+template <typename ...Args>
+void parallel_nd(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    const bool do_parallel = get_work_amount(utils::forward<Args>(args)...) > 1;
+#   pragma omp parallel if (do_parallel)
+    {
+        const int nthr = !do_parallel ? 1 : mkldnn_get_num_threads();
+        const int ithr = !do_parallel ? 0 : mkldnn_get_thread_num();
+        for_nd(ithr, nthr, utils::forward<Args>(args)...);
+    }
+#endif
+}
+#else // MKLDNN_THR != MKLDNN_THR_TBB
+
+// gcc 4.8 has a bug with passing a parameter pack to lambdas,
+// so all the cases have to be instantiated explicitly.
+
+template <typename T0, typename F>
+void parallel_nd(const T0 &D0, F f) {
+    const size_t work_amount = (size_t)D0;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(T0(iwork));
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, F f) {
+    const size_t work_amount = (size_t)D0 * D1;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1);
+            utils::nd_iterator_step(d0, D0, d1, D1);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+         typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+        const T4 &D4, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3, d4);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4);
+        }
+    }, tbb::static_partitioner());
+}
+
+template <typename T0, typename T1, typename T2, typename T3, typename T4,
+         typename T5, typename F>
+void parallel_nd(const T0 &D0, const T1 &D1, const T2 &D2, const T3 &D3,
+        const T4 &D4, const T5 &D5, F f) {
+    const size_t work_amount = (size_t)D0 * D1 * D2 * D3 * D4 * D5;
+    if (work_amount == 0) return;
+    tbb::parallel_for(tbb::blocked_range<size_t>(0, work_amount), [&](const tbb::blocked_range<size_t>& r) {
+        T0 d0{0}; T1 d1{0}; T2 d2{0}; T3 d3{0}; T4 d4{0}; T5 d5{0};
+        utils::nd_iterator_init(r.begin(), d0, D0, d1, D1, d2, D2, d3, D3, d4, D4,
+                d5, D5);
+        for (size_t iwork = r.begin(); iwork != r.end(); ++iwork) {
+            f(d0, d1, d2, d3, d4, d5);
+            utils::nd_iterator_step(d0, D0, d1, D1, d2, D2, d3, D3, d4, D4, d5, D5);
+        }
+    }, tbb::static_partitioner());
+}
+#endif
+
+template <typename ...Args>
+void parallel_nd_in_omp(Args &&...args) {
+#if MKLDNN_THR == MKLDNN_THR_SEQ
+    for_nd(0, 1, utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_OMP
+    for_nd(mkldnn_get_thread_num(), mkldnn_get_num_threads(),
+            utils::forward<Args>(args)...);
+#elif MKLDNN_THR == MKLDNN_THR_TBB
+    assert(!"unsupported parallel_nd_in_omp()");
+#endif
+}
+
+} // namespace impl
+} // namespace mkldnn
+
+#endif
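parallel_nd flattens the iteration space, carves it into contiguous chunks (via balance211 or tbb::parallel_for), and invokes the lambda once per coordinate. A minimal usage sketch (illustrative only; the buffer and sizes are invented, and -DMKLDNN_THR=0 keeps it free of any threading runtime):

    #include <vector>
    #include "mkldnn_thread.hpp"   // provides mkldnn::impl::parallel_nd

    int main() {
        const int C = 3, H = 4, W = 5;
        std::vector<float> buf(C * H * W);

        // Visits every (c, h, w) exactly once; which thread runs which chunk
        // depends on the selected MKLDNN_THR backend.
        mkldnn::impl::parallel_nd(C, H, W, [&](int c, int h, int w) {
            buf[(c * H + h) * W + w] = float(c + h + w);
        });
        return 0;
    }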

+ 77 - 0
thirdparty/oidn/mkl-dnn/src/common/mkldnn_traits.hpp

@@ -0,0 +1,77 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef MKLDNN_TRAITS_HPP
+#define MKLDNN_TRAITS_HPP
+
+#include <assert.h>
+#include <stdint.h>
+
+#include "mkldnn.h"
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+template <data_type_t> struct prec_traits {}; /* ::type -> float */
+template <typename> struct data_traits {}; /* ::data_type -> f32 */
+template <int> struct typesize_traits {}; /* size in bytes -> ::type */
+template <primitive_kind_t> struct pkind_traits {}; /* ::desc_type, ::query_d */
+
+template <> struct prec_traits<data_type::f32> { typedef float type; };
+template <> struct prec_traits<data_type::s32> { typedef int32_t type; };
+template <> struct prec_traits<data_type::s8> { typedef int8_t type; };
+template <> struct prec_traits<data_type::u8> { typedef uint8_t type; };
+
+template <> struct data_traits<float>
+{ static constexpr data_type_t data_type = data_type::f32; };
+template <> struct data_traits<int32_t>
+{ static constexpr data_type_t data_type = data_type::s32; };
+template <> struct data_traits<int8_t>
+{ static constexpr data_type_t data_type = data_type::s8; };
+template <> struct data_traits<uint8_t>
+{ static constexpr data_type_t data_type = data_type::u8; };
+
+template <> struct typesize_traits<4> { typedef float type; };
+template <> struct typesize_traits<2> { typedef int16_t type; };
+template <> struct typesize_traits<1> { typedef uint8_t type; };
+
+#define PKIND_TRAITS_INST(op) \
+template <> struct pkind_traits<primitive_kind::op> { \
+    typedef CONCAT2(op, _desc_t) desc_type; \
+    static constexpr query_t query_d = query::CONCAT2(op, _d); \
+}
+PKIND_TRAITS_INST(convolution);
+PKIND_TRAITS_INST(deconvolution);
+PKIND_TRAITS_INST(shuffle);
+PKIND_TRAITS_INST(eltwise);
+PKIND_TRAITS_INST(softmax);
+PKIND_TRAITS_INST(pooling);
+PKIND_TRAITS_INST(lrn);
+PKIND_TRAITS_INST(batch_normalization);
+PKIND_TRAITS_INST(inner_product);
+PKIND_TRAITS_INST(rnn);
+#undef PKIND_TRAITS_INST
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
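The traits are compile-time lookup tables between data_type tags and the C++ types that carry them; kernels use them to stay generic over precision. A small sketch of what they resolve to (assumes the in-tree src/common include path is available):

    #include <cstdint>
    #include <type_traits>
    #include "mkldnn_traits.hpp"

    using namespace mkldnn::impl;

    // prec_traits: data_type tag -> C++ type.
    static_assert(std::is_same<prec_traits<data_type::f32>::type, float>::value,
            "f32 is carried by float");

    // data_traits: C++ type -> data_type tag (the inverse mapping).
    static_assert(data_traits<uint8_t>::data_type == data_type::u8,
            "uint8_t is tagged as u8");

    int main() { return 0; }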

+ 193 - 0
thirdparty/oidn/mkl-dnn/src/common/nstl.hpp

@@ -0,0 +1,193 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef NSTL_HPP
+#define NSTL_HPP
+
+#include <stdint.h>
+#include <limits.h>
+#include <float.h>
+
+#include <vector>
+#include <map>
+
+#include "z_magic.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+void *malloc(size_t size, int alignment);
+void free(void *p);
+
+struct c_compatible {
+    enum { default_alignment = 64 };
+    static void *operator new(size_t sz) {
+        return malloc(sz, default_alignment);
+    }
+    static void *operator new(size_t sz, void *p) { UNUSED(sz); return p; }
+    static void *operator new[](size_t sz) {
+        return malloc(sz, default_alignment);
+    }
+    static void operator delete(void *p) { free(p); }
+    static void operator delete[](void *p) { free(p); }
+};
+
+namespace nstl {
+
+template<typename T>
+inline const T abs(const T& a) {
+    return a >= 0 ? a : -a;
+}
+
+template<typename T>
+inline const T& max(const T& a, const T& b) {
+    return a > b ? a : b;
+}
+
+template<typename T>
+inline const T& min(const T& a, const T& b) {
+    return a < b ? a : b;
+}
+
+template<typename T> void swap(T& t1, T& t2) {
+    T tmp(t1);
+    t1 = t2;
+    t2 = tmp;
+}
+
+// Rationale: MKL-DNN needs a numeric limits implementation that does not
+// generate dependencies on C++ run-time libraries.
+
+template<typename T> struct numeric_limits;
+
+template<> struct numeric_limits<float> {
+    static constexpr float lowest() { return -FLT_MAX; }
+    static constexpr float max() { return FLT_MAX; }
+};
+
+template<> struct numeric_limits<int32_t> {
+    static constexpr int lowest() { return INT32_MIN; }
+    static constexpr int max() { return INT32_MAX; }
+};
+
+template<> struct numeric_limits<int16_t> {
+    static constexpr int16_t lowest() { return INT16_MIN; }
+    static constexpr int16_t max() { return INT16_MAX; }
+};
+
+template<> struct numeric_limits<int8_t> {
+    static constexpr int8_t lowest() { return INT8_MIN; }
+    static constexpr int8_t max() { return INT8_MAX; }
+};
+
+template<> struct numeric_limits<uint8_t> {
+    static constexpr uint8_t lowest() { return 0; }
+    static constexpr uint8_t max() { return UINT8_MAX; }
+};
+
+template<typename T> struct is_integral
+{ static constexpr bool value = false; };
+template<> struct is_integral<int32_t> { static constexpr bool value = true; };
+template<> struct is_integral<int16_t> { static constexpr bool value = true; };
+template<> struct is_integral<int8_t> { static constexpr bool value = true; };
+template<> struct is_integral<uint8_t> { static constexpr bool value = true; };
+
+template <typename T, typename U> struct is_same
+{ static constexpr bool value = false; };
+template <typename T> struct is_same<T, T>
+{ static constexpr bool value = true; };
+
+// Rationale: MKL-DNN needs container implementations that do not generate
+// dependencies on C++ run-time libraries.
+//
+// Implementation philosophy: the caller is responsible for checking that the
+// operation is valid. The only functions that have to return a status are
+// those that depend on memory allocation or similar operations.
+//
+// This means that e.g. an operator [] does not have to check for boundaries.
+// The caller should have checked the boundaries. If it did not, we crash and
+// burn: this is a bug in MKL-DNN and throwing an exception would not have been
+// recoverable.
+//
+// On the other hand, insert() or resize() or a similar operation needs to
+// return a status because the outcome depends on factors external to the
+// caller. The situation is probably not recoverable either, but MKL-DNN
+// needs to be nice and report "out of memory" to the users.
+
+enum nstl_status_t {
+    success = 0,
+    out_of_memory
+};
+
+template <typename T> class vector: public c_compatible {
+private:
+    std::vector<T> _impl;
+public:
+    typedef typename std::vector<T>::iterator iterator;
+    typedef typename std::vector<T>::const_iterator const_iterator;
+    typedef typename std::vector<T>::size_type size_type;
+    vector() {}
+    vector(size_type n): _impl(n) {}
+    vector(size_type n, const T &value): _impl(n, value) {}
+    template <typename input_iterator>
+    vector(input_iterator first, input_iterator last): _impl(first, last) {}
+    ~vector() {}
+    size_type size() const { return _impl.size(); }
+    T& operator[] (size_type i) { return _impl[i]; }
+    const T& operator[] (size_type i) const { return _impl[i]; }
+    iterator begin() { return _impl.begin(); }
+    const_iterator begin() const { return _impl.begin(); }
+    iterator end() { return _impl.end(); }
+    const_iterator end() const { return _impl.end(); }
+    template <typename input_iterator>
+    nstl_status_t insert(iterator pos, input_iterator begin, input_iterator end)
+    {
+        _impl.insert(pos, begin, end);
+        return success;
+    }
+    void clear() { _impl.clear(); }
+    void push_back(const T& t) { _impl.push_back(t); }
+    void resize(size_type count) { _impl.resize(count); }
+    void reserve(size_type count) { _impl.reserve(count); }
+};
+
+template <typename Key, typename T> class map: public c_compatible {
+private:
+    std::map<Key, T> _impl;
+public:
+    typedef typename std::map<Key, T>::iterator iterator;
+    typedef typename std::map<Key, T>::const_iterator const_iterator;
+    typedef typename std::map<Key, T>::size_type size_type;
+    map() {}
+    ~map() {}
+    size_type size() const { return _impl.size(); }
+    T& operator[](const Key &k) { return _impl[k]; }
+    const T& operator[](const Key &k) const { return _impl[k]; }
+    iterator begin() { return _impl.begin(); }
+    const_iterator begin() const { return _impl.begin(); }
+    iterator end() { return _impl.end(); }
+    const_iterator end() const { return _impl.end(); }
+    template <typename input_iterator>
+    void clear() { _impl.clear(); }
+};
+
+}
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
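For this header-only usage the nstl containers behave like thin wrappers over their std counterparts; the aligned operator new from c_compatible only kicks in when a container object is itself heap-allocated. A short sketch (illustrative only):

    #include <cstdio>
    #include "nstl.hpp"

    using namespace mkldnn::impl;

    int main() {
        nstl::vector<int> v;
        v.push_back(3);
        v.push_back(8);

        // nstl::max and nstl::numeric_limits avoid <algorithm>/<limits>.
        std::printf("max = %d\n", nstl::max(v[0], v[1]));             // 8
        std::printf("lowest float = %g\n",
                nstl::numeric_limits<float>::lowest());               // -FLT_MAX
        return 0;
    }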

+ 114 - 0
thirdparty/oidn/mkl-dnn/src/common/pooling.cpp

@@ -0,0 +1,114 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::prop_kind;
+using namespace mkldnn::impl::alg_kind;
+using namespace mkldnn::impl::types;
+
+namespace {
+status_t pooling_desc_init(pooling_desc_t *pool_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t kernel, const dims_t padding_l,
+        const dims_t padding_r, padding_kind_t padding_kind) {
+    bool args_ok = true
+        && !any_null(pool_desc, src_desc, dst_desc, strides, kernel, padding_l)
+        && one_of(alg_kind, pooling_max,
+                pooling_avg_include_padding,
+                pooling_avg_exclude_padding)
+        && one_of(padding_kind, padding_kind::padding_zero);
+    if (!args_ok) return invalid_arguments;
+
+    if (padding_r == nullptr) padding_r = padding_l;
+
+    auto pd = pooling_desc_t();
+    pd.primitive_kind = primitive_kind::pooling;
+    pd.prop_kind = prop_kind;
+    pd.alg_kind = alg_kind;
+    pd.src_desc.ndims = src_desc->ndims;
+
+    const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
+
+    pd.diff_src_desc = pd.src_desc = zero_md();
+    pd.diff_dst_desc = pd.dst_desc = zero_md();
+
+    (is_fwd ? pd.src_desc : pd.diff_src_desc) = *src_desc;
+    (is_fwd ? pd.dst_desc : pd.diff_dst_desc) = *dst_desc;
+
+    int sp_dims = src_desc->ndims - 2;
+    utils::array_copy(pd.strides, strides, sp_dims);
+    utils::array_copy(pd.kernel, kernel, sp_dims);
+    utils::array_copy(pd.padding[0], padding_l, sp_dims);
+    utils::array_copy(pd.padding[1], padding_r, sp_dims);
+
+    pd.padding_kind = padding_kind;
+    if (one_of(alg_kind, pooling_max, pooling_avg_include_padding,
+                pooling_avg_exclude_padding)) {
+        pd.accum_data_type = types::default_accum_data_type(
+                src_desc->data_type, dst_desc->data_type);
+    } else {
+        pd.accum_data_type = dst_desc->data_type;
+    }
+
+    bool consistency = true
+        && utils::one_of(src_desc->ndims, 4, 5)
+        && utils::one_of(dst_desc->ndims, 4, 5)
+        && src_desc->dims[0] == dst_desc->dims[0]
+        && src_desc->dims[1] == dst_desc->dims[1];
+    for (int i = 2; i < src_desc->ndims; ++i)
+        consistency = consistency && (
+                (src_desc->dims[i] - kernel[i - 2] + padding_l[i - 2]
+                 + padding_r[i - 2]) / strides[i - 2] + 1
+                == dst_desc->dims[i]);
+    if (!consistency) return invalid_arguments;
+
+    *pool_desc = pd;
+    return success;
+}
+}
+
+status_t mkldnn_pooling_forward_desc_init(pooling_desc_t *pool_desc,
+        prop_kind_t prop_kind, alg_kind_t alg_kind,
+        const memory_desc_t *src_desc, const memory_desc_t *dst_desc,
+        const dims_t strides, const dims_t kernel, const dims_t padding_l,
+        const dims_t padding_r, padding_kind_t padding_kind) {
+    if (!one_of(prop_kind, forward_training, forward_inference))
+        return invalid_arguments;
+    return pooling_desc_init(pool_desc, prop_kind, alg_kind, src_desc,
+            dst_desc, strides, kernel, padding_l, padding_r, padding_kind);
+}
+
+status_t mkldnn_pooling_backward_desc_init(pooling_desc_t *pool_desc,
+        alg_kind_t alg_kind, const memory_desc_t *diff_src_desc,
+        const memory_desc_t *diff_dst_desc, const dims_t strides,
+        const dims_t kernel, const dims_t padding_l, const dims_t padding_r,
+        padding_kind_t padding_kind) {
+    return pooling_desc_init(pool_desc, prop_kind::backward_data, alg_kind,
+            diff_src_desc, diff_dst_desc, strides, kernel, padding_l,
+            padding_r, padding_kind);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
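The consistency check in pooling_desc_init() encodes the usual pooling output-size relation, out = (in - kernel + pad_l + pad_r) / stride + 1, applied per spatial dimension. A standalone sketch of that arithmetic (illustrative only; the helper pooled_dim is invented for this note and does not call the library):

    #include <cstdio>

    // Mirrors the per-dimension check in pooling_desc_init() above.
    static int pooled_dim(int in, int kernel, int stride, int pad_l, int pad_r) {
        return (in - kernel + pad_l + pad_r) / stride + 1;
    }

    int main() {
        std::printf("%d\n", pooled_dim(32, 2, 2, 0, 0));  // 2x2 window, stride 2 on 32 -> 16
        std::printf("%d\n", pooled_dim(7, 3, 2, 1, 1));   // (7 - 3 + 2) / 2 + 1 = 4
        return 0;
    }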

+ 238 - 0
thirdparty/oidn/mkl-dnn/src/common/pooling_pd.hpp

@@ -0,0 +1,238 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef POOLING_PD_HPP
+#define POOLING_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct pooling_fwd_pd_t;
+
+struct pooling_pd_t: public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::pooling;
+
+    pooling_pd_t(engine_t *engine,
+            const pooling_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const pooling_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , ws_md_()
+    {}
+
+    const pooling_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::pooling_d:
+            *(const pooling_desc_t**)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    /* common pooling aux functions */
+
+    dim_t MB() const { return src_desc().dims[0]; }
+    dim_t C() const { return src_desc().dims[1]; }
+
+    dim_t ID() const { return ndims() >= 5 ? src_desc().dims[ndims() - 3] : 1; }
+    dim_t IH() const { return ndims() >= 4 ? src_desc().dims[ndims() - 2] : 1; }
+    dim_t IW() const { return src_desc().dims[ndims() - 1]; }
+
+    dim_t OD() const { return ndims() >= 5 ? dst_desc().dims[ndims() - 3] : 1; }
+    dim_t OH() const { return ndims() >= 4 ? dst_desc().dims[ndims() - 2] : 1; }
+    dim_t OW() const { return dst_desc().dims[ndims() - 1]; }
+
+    dim_t KD() const { return ndims() >= 5 ? desc_.kernel[ndims() - 5] : 1; }
+    dim_t KH() const { return ndims() >= 4 ? desc_.kernel[ndims() - 4] : 1; }
+    dim_t KW() const { return desc_.kernel[ndims() - 3]; }
+
+    dim_t KSD() const { return ndims() >= 5 ? desc_.strides[ndims() - 5] : 1; }
+    dim_t KSH() const { return ndims() >= 4 ? desc_.strides[ndims() - 4] : 1; }
+    dim_t KSW() const { return desc_.strides[ndims() - 3]; }
+
+    dim_t padFront() const
+    { return ndims() >= 5 ? desc_.padding[0][ndims() - 5] : 0; }
+    dim_t padBack() const
+    { return ndims() >= 5 ? desc_.padding[1][ndims() - 5] : 0; }
+    dim_t padT() const
+    { return ndims() >= 4 ? desc_.padding[0][ndims() - 4] : 0; }
+    dim_t padB() const
+    { return ndims() >= 4 ? desc_.padding[1][ndims() - 4] : 0; }
+    dim_t padL() const { return desc_.padding[0][ndims() - 3]; }
+    dim_t padR() const { return desc_.padding[1][ndims() - 3]; }
+
+    int ndims() const { return src_desc().ndims; }
+    bool is_3d() const { return ndims() == 5; }
+
+    bool has_zero_dim_memory() const
+    { return memory_desc_wrapper(src_desc()).has_zero_dim(); }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+protected:
+    pooling_desc_t desc_;
+    const pooling_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t ws_md_;
+
+    void init_default_ws() {
+        ws_md_ = is_fwd() ? *dst_md() : *diff_dst_md();
+        ws_md_.data_type = indices_data_type();
+    }
+
+    data_type_t indices_data_type() const {
+        /* the simplest way to express 256... */
+        const int u8_max = nstl::numeric_limits<
+            typename prec_traits<data_type::u8>::type>::max();
+        return utils::array_product(desc()->kernel, ndims()) <= u8_max
+            ? data_type::u8 : data_type::s32;
+    }
+
+private:
+    const memory_desc_t &src_desc() const
+    { return is_fwd() ? desc_.src_desc : desc_.diff_src_desc; }
+    const memory_desc_t &dst_desc() const
+    { return is_fwd() ? desc_.dst_desc : desc_.diff_dst_desc; }
+};
+
+struct pooling_fwd_pd_t: public pooling_pd_t {
+    typedef pooling_fwd_pd_t base_class;
+    typedef pooling_fwd_pd_t hint_class;
+
+    pooling_fwd_pd_t(engine_t *engine,
+            const pooling_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const pooling_fwd_pd_t *hint_fwd_pd)
+        : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , src_md_(desc_.src_desc)
+        , dst_md_(desc_.dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override
+    { return 1 + (workspace_md() != nullptr); }
+
+protected:
+    memory_desc_t src_md_;
+    memory_desc_t dst_md_;
+
+    virtual status_t set_default_params() {
+        if (dst_md()->format_kind != format_kind::any)
+            return status::success;
+
+        if (src_md()->format_kind != format_kind::blocked)
+            return status::unimplemented;
+
+        return memory_desc_init_by_blocking_desc(dst_md_,
+                src_md_.format_desc.blocking);
+    }
+};
+
+struct pooling_bwd_pd_t: public pooling_pd_t {
+    typedef pooling_bwd_pd_t base_class;
+    typedef pooling_fwd_pd_t hint_class;
+
+    pooling_bwd_pd_t(engine_t *engine,
+            const pooling_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const pooling_fwd_pd_t *hint_fwd_pd)
+        : pooling_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_src_md_(desc_.diff_src_desc)
+        , diff_dst_md_(desc_.diff_dst_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_DIFF_DST)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DIFF_SRC)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && (workspace_md() != nullptr))
+            return arg_usage_t::input;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override
+    { return index == 0 ? &diff_src_md_ : nullptr; }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override
+    { return index == 0 ? &diff_dst_md_ : nullptr; }
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    virtual int n_inputs() const override
+    { return 1 + (workspace_md() != nullptr); }
+    virtual int n_outputs() const override { return 1; }
+
+protected:
+    memory_desc_t diff_src_md_;
+    memory_desc_t diff_dst_md_;
+
+    virtual status_t set_default_params() {
+        if (diff_src_md()->format_kind != format_kind::any)
+            return status::success;
+
+        if (diff_dst_md()->format_kind != format_kind::blocked)
+            return status::unimplemented;
+
+        return memory_desc_init_by_blocking_desc(diff_src_md_,
+                diff_dst_md_.format_desc.blocking);
+    }
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
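indices_data_type() picks the narrowest type that can index any element of the pooling window in the workspace that max pooling keeps: u8 while the window volume fits an 8-bit index, s32 otherwise. A standalone sketch of that decision (illustrative only; the helper ws_indices_type is invented for this note and does not use the pd class):

    #include <cstdio>

    // Mirrors the choice made in pooling_pd_t::indices_data_type() above.
    static const char *ws_indices_type(int kd, int kh, int kw) {
        return (kd * kh * kw <= 255) ? "u8" : "s32";
    }

    int main() {
        std::printf("3x3 window   -> %s\n", ws_indices_type(1, 3, 3));   // 9 <= 255 -> u8
        std::printf("7x7x7 window -> %s\n", ws_indices_type(7, 7, 7));   // 343 > 255 -> s32
        return 0;
    }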

+ 103 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive.cpp

@@ -0,0 +1,103 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "primitive.hpp"
+#include "type_helpers.hpp"
+#include "stream.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::primitive_kind;
+
+namespace {
+// XXX: this is a huge hammer. This disables all and any msan checks on
+// primitives outputs.
+//
+// A proper approach would be an implementation-specific unpoisoning.
+void unpoison_outputs(const exec_args_t &args) {
+    for(const auto &arg: args) {
+        if (arg.second.is_const) continue;
+        auto *mem = arg.second.mem;
+        void *p;
+        mem->get_data_handle(&p);
+        size_t s = memory_desc_wrapper(*mem->md()).size();
+        msan_unpoison(p, s);
+    }
+}
+}
+
+status_t mkldnn_primitive_desc_destroy(primitive_desc_t *primitive_desc) {
+    if (primitive_desc) delete primitive_desc;
+    return success;
+}
+
+status_t mkldnn_primitive_create(primitive_t **primitive,
+        const primitive_desc_t *primitive_desc) {
+    if (utils::any_null(primitive, primitive_desc))
+        return invalid_arguments;
+    return primitive_desc->create_primitive(primitive);
+}
+
+status_t mkldnn_primitive_execute(const primitive_t *primitive,
+        stream_t *stream, int nargs, const mkldnn_exec_arg_t *c_args) {
+    bool ok = true
+        && !utils::any_null(primitive, stream)
+        && primitive->engine() == stream->engine()
+        && IMPLICATION(nargs > 0, c_args != nullptr);
+    if (!ok) return invalid_arguments;
+
+    exec_args_t args;
+    status_t status = cvt_primtive_args(primitive->pd(), nargs, c_args, args);
+    if (status != status::success) return status;
+
+    exec_ctx_t ctx(stream, std::move(args));
+
+    if (mkldnn_verbose()->level) {
+        double ms = get_msec();
+        status = primitive->execute(ctx);
+        ms = get_msec() - ms;
+        printf("mkldnn_verbose,exec,%s,%g\n", primitive->pd()->info(), ms);
+        fflush(0);
+    } else {
+        status = primitive->execute(ctx);
+    }
+
+    if (msan_enabled) unpoison_outputs(ctx.args());
+
+    return status;
+}
+
+status_t mkldnn_primitive_get_primitive_desc(const primitive_t *primitive,
+        const primitive_desc_t **primitive_desc) {
+    if (utils::any_null(primitive, primitive_desc))
+        return invalid_arguments;
+    return safe_ptr_assign<const primitive_desc_t>(*primitive_desc,
+            primitive->pd());
+}
+
+status_t mkldnn_primitive_destroy(primitive_t *primitive) {
+    if (primitive != nullptr)
+        delete primitive;
+    return success;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 76 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive.hpp

@@ -0,0 +1,76 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_HPP
+#define PRIMITIVE_HPP
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+#include "primitive_exec_types.hpp"
+
+/** \brief A pure virtual primitive class
+ *
+ * A primitive holds links to its inputs & outputs, though it does not track
+ * their readiness at execution time.
+ *
+ * @remark @b Rationale.
+ *   Dependencies are essential throughout the whole MKL-DNN library, so it
+ *   makes sense to include them at the very lowest level. On the other hand,
+ *   tracking them should be the task of the corresponding entity, such as a
+ *   scheduler or a stream. The primitive itself should know nothing about the
+ *   environment it is running in.
+ *
+ * @note
+ *   To improve the user experience, the API should allow achieving the best
+ *   (or good enough) performance when primitives are created in their natural
+ *   order: i.e. from bottom to top for the forward pass and from top to
+ *   bottom for the backward pass. Please consider restriction [1] in Level 0.
+ */
+struct mkldnn_primitive: public mkldnn::impl::c_compatible {
+    mkldnn_primitive(const mkldnn::impl::primitive_desc_t *pd)
+        : pd_(pd->clone()) {}
+    virtual ~mkldnn_primitive() { delete pd_; }
+
+    /** returns primitive's engine */
+    mkldnn::impl::engine_t *engine() const { return pd_->engine(); }
+    /** returns primitive's descriptor */
+    const mkldnn::impl::primitive_desc_t *pd() const { return pd_; }
+    /** returns primitive's kind */
+    mkldnn::impl::primitive_kind_t kind() const { return pd_->kind(); }
+
+    /** executes primitive with execution context @p ctx */
+    virtual mkldnn::impl::status_t execute(const mkldnn::impl::exec_ctx_t &ctx)
+        const = 0;
+
+protected:
+    const mkldnn::impl::primitive_desc_t *pd_;
+
+private:
+    mkldnn_primitive() = delete;
+    mkldnn_primitive(const mkldnn_primitive &) = delete;
+    mkldnn_primitive(mkldnn_primitive &&) = delete;
+    mkldnn_primitive &operator=(const mkldnn_primitive &) = delete;
+    mkldnn_primitive &operator=(mkldnn_primitive &&) = delete;
+};
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 290 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.cpp

@@ -0,0 +1,290 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::utils;
+
+namespace mkldnn {
+namespace impl {
+
+status_t scales_t::set(dim_t count, int mask, const float *scales) {
+    cleanup();
+
+    count_ = count;
+    mask_ = mask;
+
+    if (count_ == 1) {
+        scales_ = scales_buf_;
+        utils::array_set(scales_, scales[0], scales_buf_size);
+    } else {
+        scales_ = (float *)impl::malloc(count_ * sizeof(*scales_), 64);
+        if (scales_ == nullptr)
+            return status::out_of_memory;
+
+        for (dim_t c = 0; c < count_; ++c)
+            scales_[c] = scales[c];
+    }
+
+    return status::success;
+}
+
+}
+}
+
+status_t post_ops_t::append_sum(float scale) {
+    if (len_ == capacity)
+        return out_of_memory;
+
+    entry_[len_].kind = primitive_kind::sum;
+    entry_[len_].sum.scale = scale;
+
+    len_++;
+
+    return success;
+}
+
+status_t post_ops_t::append_eltwise(float scale, alg_kind_t alg, float alpha,
+        float beta) {
+    using namespace mkldnn::impl::alg_kind;
+    bool known_alg = one_of(alg, eltwise_relu, eltwise_tanh, eltwise_elu,
+            eltwise_square, eltwise_abs, eltwise_sqrt, eltwise_linear,
+            eltwise_bounded_relu, eltwise_soft_relu, eltwise_logistic);
+    if (!known_alg)
+        return invalid_arguments;
+
+    if (len_ == capacity)
+        return out_of_memory;
+
+    entry_[len_].kind = primitive_kind::eltwise;
+    entry_[len_].eltwise.scale = scale;
+    entry_[len_].eltwise.alg = alg;
+    entry_[len_].eltwise.alpha = alpha;
+    entry_[len_].eltwise.beta = beta;
+
+    len_++;
+
+    return success;
+}
+
+status_t primitive_attr_t::set_scratchpad_mode(
+        scratchpad_mode_t scratchpad_mode) {
+    using namespace mkldnn::impl::scratchpad_mode;
+
+    const bool ok = one_of(scratchpad_mode, library, user);
+    if (!ok)
+        return invalid_arguments;
+
+    scratchpad_mode_ = scratchpad_mode;
+    return success;
+}
+
+status_t primitive_attr_t::set_post_ops(const post_ops_t &post_ops) {
+    this->post_ops_ = post_ops;
+    return success;
+}
+
+/* Public C API */
+
+status_t mkldnn_primitive_attr_create(primitive_attr_t **attr) {
+    if (attr == nullptr)
+        return invalid_arguments;
+
+    return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+            new mkldnn_primitive_attr);
+}
+
+status_t mkldnn_primitive_attr_clone(primitive_attr_t **attr,
+        const primitive_attr_t *existing_attr) {
+    if (any_null(attr, existing_attr))
+        return invalid_arguments;
+
+    return safe_ptr_assign<mkldnn_primitive_attr>(*attr,
+            existing_attr->clone());
+}
+
+status_t mkldnn_primitive_attr_destroy(primitive_attr_t *attr) {
+    if (attr)
+        delete attr;
+
+    return success;
+}
+
+status_t mkldnn_primitive_attr_get_scratchpad_mode(
+        const primitive_attr_t *attr, scratchpad_mode_t *scratchpad_mode) {
+    if (any_null(attr, scratchpad_mode))
+        return invalid_arguments;
+
+    *scratchpad_mode = attr->scratchpad_mode_;
+
+    return success;
+}
+
+status_t mkldnn_primitive_attr_set_scratchpad_mode(
+        primitive_attr_t *attr, scratchpad_mode_t scratchpad_mode) {
+    if (any_null(attr))
+        return invalid_arguments;
+
+    return attr->set_scratchpad_mode(scratchpad_mode);
+}
+
+status_t mkldnn_primitive_attr_get_output_scales(const primitive_attr_t *attr,
+        dim_t *count, int *mask, const float **scales) {
+    if (any_null(attr, count, mask, scales))
+        return invalid_arguments;
+
+    *count = attr->output_scales_.count_;
+    *mask = attr->output_scales_.mask_;
+    *scales = attr->output_scales_.scales_;
+
+    return success;
+}
+
+status_t mkldnn_primitive_attr_set_output_scales(primitive_attr_t *attr,
+        dim_t count, int mask, const float *scales) {
+    bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+    if (!ok)
+        return invalid_arguments;
+
+    return attr->output_scales_.set(count, mask, scales);
+}
+
+status_t mkldnn_primitive_attr_get_post_ops(const primitive_attr_t *attr,
+        const post_ops_t **post_ops) {
+    if (any_null(attr, post_ops))
+        return invalid_arguments;
+
+    *post_ops = &attr->post_ops_;
+    return success;
+}
+
+status_t mkldnn_primitive_attr_set_post_ops(primitive_attr_t *attr,
+        const post_ops_t *post_ops) {
+    if (any_null(attr, post_ops))
+        return invalid_arguments;
+
+    return attr->set_post_ops(*post_ops);
+}
+
+status_t mkldnn_post_ops_create(post_ops_t **post_ops) {
+    if (post_ops == nullptr)
+        return invalid_arguments;
+
+    return safe_ptr_assign<mkldnn_post_ops>(*post_ops, new mkldnn_post_ops);
+}
+
+status_t mkldnn_post_ops_destroy(post_ops_t *post_ops) {
+    if (post_ops)
+        delete post_ops;
+
+    return success;
+}
+
+int mkldnn_post_ops_len(const post_ops_t *post_ops) {
+    if (post_ops)
+        return post_ops->len_;
+
+    return 0;
+}
+
+primitive_kind_t mkldnn_post_ops_get_kind(const post_ops_t *post_ops,
+        int index) {
+    bool ok = post_ops && 0 <= index && index < post_ops->len_;
+    if (!ok)
+        return primitive_kind::undefined;
+
+    return post_ops->entry_[index].kind;
+}
+
+status_t mkldnn_post_ops_append_sum(post_ops_t *post_ops, float scale) {
+    if (post_ops == nullptr)
+        return invalid_arguments;
+
+    return post_ops->append_sum(scale);
+}
+
+namespace {
+bool simple_get_params_check(const post_ops_t *post_ops, int index,
+        primitive_kind_t kind) {
+    bool ok = true
+        && post_ops != nullptr
+        && 0 <= index
+        && index < post_ops->len_
+        && post_ops->entry_[index].kind == kind;
+    return ok;
+}
+}
+
+status_t mkldnn_post_ops_get_params_sum(const post_ops_t *post_ops, int index,
+        float *scale) {
+    bool ok = true
+        && simple_get_params_check(post_ops, index, primitive_kind::sum)
+        && !any_null(scale);
+    if (!ok)
+        return invalid_arguments;
+
+    *scale = post_ops->entry_[index].sum.scale;
+    return success;
+}
+
+status_t mkldnn_post_ops_append_eltwise(post_ops_t *post_ops, float scale,
+        alg_kind_t kind, float alpha, float beta) {
+    if (post_ops == nullptr)
+        return invalid_arguments;
+
+    return post_ops->append_eltwise(scale, kind, alpha, beta);
+}
+
+status_t mkldnn_post_ops_get_params_eltwise(const post_ops_t *post_ops,
+        int index, float *scale, alg_kind_t *alg, float *alpha, float *beta) {
+    bool ok = true
+        && simple_get_params_check(post_ops, index, primitive_kind::eltwise)
+        && !any_null(scale, alpha, beta);
+    if (!ok)
+        return invalid_arguments;
+
+    const auto &e = post_ops->entry_[index].eltwise;
+    *scale = e.scale;
+    *alg = e.alg;
+    *alpha = e.alpha;
+    *beta = e.beta;
+
+    return success;
+}
+
+status_t mkldnn_primitive_attr_set_rnn_data_qparams(
+        primitive_attr_t *attr, const float scale, const float shift) {
+    if (attr == nullptr)
+        return invalid_arguments;
+
+    return attr->rnn_data_qparams_.set(scale, shift);
+}
+
+status_t mkldnn_primitive_attr_set_rnn_weights_qparams(
+        primitive_attr_t *attr, dim_t count, int mask, const float *scales) {
+    bool ok = !any_null(attr, scales) && count > 0 && mask >= 0;
+    if (!ok)
+        return invalid_arguments;
+
+    return attr->rnn_weights_qparams_.set(count, mask, scales);
+}
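
Taken together, the C entry points above are used roughly as follows; a hedged sketch in which the enum spellings (mkldnn_eltwise_relu, mkldnn_scratchpad_mode_user) are assumed from mkldnn_types.h and error handling is omitted:

    #include "mkldnn.h"

    // Build an attribute with one ReLU post-op and user-managed scratchpad.
    void build_attr_example(mkldnn_primitive_attr_t *out_attr) {
        mkldnn_post_ops_t ops;
        mkldnn_post_ops_create(&ops);
        // scale = 1.0f, alg = ReLU, alpha (negative slope) = 0.f, beta unused
        mkldnn_post_ops_append_eltwise(ops, 1.0f, mkldnn_eltwise_relu, 0.f, 0.f);

        mkldnn_primitive_attr_t attr;
        mkldnn_primitive_attr_create(&attr);
        mkldnn_primitive_attr_set_post_ops(attr, ops); // the attr stores a copy
        mkldnn_primitive_attr_set_scratchpad_mode(attr, mkldnn_scratchpad_mode_user);

        mkldnn_post_ops_destroy(ops); // safe: set_post_ops copied the list
        *out_attr = attr; // caller releases it with mkldnn_primitive_attr_destroy
    }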

+ 183 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_attr.hpp

@@ -0,0 +1,183 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_ATTR_HPP
+#define PRIMITIVE_ATTR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_data_qparams_t : public c_compatible {
+    rnn_data_qparams_t() : scale_(1.), shift_(0.) {}
+    bool has_default_values() const { return (scale_ == 1. && shift_ == 0.); }
+
+    status_t set(float scale, float shift) {
+        scale_ = scale;
+        shift_ = shift;
+        return status::success;
+    }
+
+    float scale_;
+    float shift_;
+};
+
+struct scales_t: public c_compatible {
+    scales_t(): count_(1), mask_(0), scales_(scales_buf_)
+    { set(1.); }
+
+    scales_t(const scales_t &rhs): scales_t()
+    { set(rhs.count_, rhs.mask_, rhs.scales_); }
+
+    ~scales_t() { cleanup(); }
+
+    scales_t &operator=(const scales_t &rhs) {
+        if (&rhs == this)
+            return *this;
+        status_t status = set(rhs.count_, rhs.mask_, rhs.scales_);
+        assert(status == status::success);
+        (void)status;
+        return *this;
+    }
+
+    bool has_default_values() const {
+        for (dim_t c = 0; c < count_; ++c) {
+            if(scales_[c] != 1.) return false;
+        }
+        return true;
+    }
+
+    status_t set(dim_t count, int mask, const float *scales);
+    status_t set(float single_scale) { return this->set(1, 0, &single_scale); }
+
+    dim_t count_;
+    int mask_;
+    float *scales_;
+
+private:
+    enum { scales_buf_size = 16 };
+    float scales_buf_[scales_buf_size];
+
+    void cleanup() {
+        if (scales_ != scales_buf_ && scales_ != nullptr)
+            impl::free(scales_);
+
+        count_ = 1;
+        mask_ = 0;
+        scales_ = scales_buf_;
+    }
+};
+
+}
+}
+
+struct mkldnn_post_ops: public mkldnn::impl::c_compatible {
+    struct entry_t {
+        struct eltwise_t {
+            mkldnn::impl::alg_kind_t alg;
+            float scale, alpha, beta;
+        };
+
+        mkldnn::impl::primitive_kind_t kind;
+        union {
+            struct { float scale; } sum;
+            eltwise_t eltwise;
+        };
+
+        bool is_eltwise(bool require_scale_one = true) const {
+            using namespace mkldnn::impl;
+            return kind == primitive_kind::eltwise
+                && IMPLICATION(require_scale_one, eltwise.scale == 1.f);
+        }
+
+        bool is_relu(bool require_scale_one = true,
+                bool require_nslope_zero = true) const {
+            using namespace mkldnn::impl;
+            return is_eltwise(require_scale_one)
+                && eltwise.alg == alg_kind::eltwise_relu
+                && IMPLICATION(require_nslope_zero, eltwise.alpha == 0.f);
+        }
+
+        bool is_sum(bool require_scale_one = true) const {
+            using namespace mkldnn::impl;
+            return kind == primitive_kind::sum
+                && IMPLICATION(require_scale_one, sum.scale == 1.f);
+        }
+    };
+
+    mkldnn_post_ops(): len_(0) {}
+
+    mkldnn::impl::status_t append_sum(float scale);
+    mkldnn::impl::status_t append_eltwise(float scale,
+            mkldnn::impl::alg_kind_t alg, float alpha, float beta);
+
+    int find(mkldnn::impl::primitive_kind_t kind, int start = 0,
+            int stop = -1) const {
+        if (stop == -1) stop = len_;
+        stop = mkldnn::impl::nstl::min(stop, len_);
+        for (int idx = start; idx < stop; ++idx)
+            if (entry_[idx].kind == kind) return idx;
+        return -1;
+    }
+
+    bool has_default_values() const { return len_ == 0; }
+
+    bool contain(mkldnn::impl::primitive_kind_t kind, int index) const
+    { return find(kind, index, index + 1) == index; }
+
+    enum { capacity = 4 };
+
+    int len_;
+    entry_t entry_[capacity];
+};
+
+struct mkldnn_primitive_attr: public mkldnn::impl::c_compatible {
+    mkldnn_primitive_attr()
+        : scratchpad_mode_(mkldnn::impl::scratchpad_mode::library)
+    {}
+
+    mkldnn_primitive_attr *clone() const
+    { return new mkldnn_primitive_attr(*this); }
+
+    /** Returns true if the attributes have default values.
+     *
+     * @note The scratchpad_mode_ is not taken into account */
+    bool has_default_values() const {
+        return true
+            && output_scales_.has_default_values()
+            && post_ops_.has_default_values()
+            && rnn_data_qparams_.has_default_values()
+            && rnn_weights_qparams_.has_default_values();
+    }
+
+    mkldnn::impl::status_t set_scratchpad_mode(
+            mkldnn::impl::scratchpad_mode_t scratchpad_mode);
+    mkldnn::impl::status_t set_post_ops(
+            const mkldnn::impl::post_ops_t &post_ops);
+
+    mkldnn::impl::scratchpad_mode_t scratchpad_mode_;
+    mkldnn::impl::scales_t output_scales_;
+    mkldnn::impl::post_ops_t post_ops_;
+    mkldnn::impl::rnn_data_qparams_t rnn_data_qparams_;
+    mkldnn::impl::scales_t rnn_weights_qparams_;
+};
+
+#endif
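
A short sketch (illustrative only, using just the members declared above) of how an implementation typically decides whether it can honor a given attribute; the function name supports_attr_example is hypothetical:

    #include "primitive_attr.hpp"

    // Accept either no post-ops, or a single ReLU with scale 1 and zero slope,
    // and require output scales to be left at their defaults.
    bool supports_attr_example(const mkldnn::impl::primitive_attr_t *attr) {
        using namespace mkldnn::impl;
        const post_ops_t &po = attr->post_ops_;
        const bool po_ok = po.has_default_values()
                || (po.len_ == 1 && po.entry_[0].is_relu());
        return po_ok && attr->output_scales_.has_default_values();
    }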

+ 78 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.cpp

@@ -0,0 +1,78 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "nstl.hpp"
+#include "primitive_desc.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t primitive_desc_t::query(query_t what, int idx, void *result) const {
+    auto safe_ret_md = [&](const memory_desc_t *_) {
+        if (_ == nullptr) return not_required;
+        *(const memory_desc_t **)result = _;
+        return success;
+    };
+
+    switch (what) {
+        case query::engine: *(engine_t**)result = engine(); break;
+        case query::primitive_kind: *(primitive_kind_t*)result = kind(); break;
+
+        case query::scratchpad_engine:
+            *(engine_t**)result = scratchpad_engine(); break;
+
+        case query::memory_consumption_s64:
+            *(dim_t *)result = scratchpad_size(scratchpad_mode::library); break;
+
+        case query::op_d:
+            if (idx != 0 || op_desc() == nullptr) return invalid_arguments;
+            *(const_c_op_desc_t *)result
+                = static_cast<const_c_op_desc_t>(op_desc()); break;
+
+        case query::src_md: return safe_ret_md(src_md(idx));
+        case query::diff_src_md: return safe_ret_md(diff_src_md(idx));
+        case query::dst_md: return safe_ret_md(dst_md(idx));
+        case query::diff_dst_md: return safe_ret_md(diff_dst_md(idx));
+        case query::weights_md: return safe_ret_md(weights_md(idx));
+        case query::diff_weights_md: return safe_ret_md(diff_weights_md(idx));
+        case query::workspace_md:
+            if (idx != 0) return status::invalid_arguments;
+            return safe_ret_md(workspace_md(idx));
+        case query::scratchpad_md:
+            if (idx != 0) return status::invalid_arguments;
+            return safe_ret_md(scratchpad_md(idx));
+
+        case query::num_of_inputs_s32: *(int*)result = n_inputs(); break;
+        case query::num_of_outputs_s32: *(int*)result = n_outputs(); break;
+
+        case query::impl_info_str: *(const char **)result = name(); break;
+
+        default: return unimplemented;
+    }
+    return success;
+}
+
+status_t mkldnn_primitive_desc_get_attr(const primitive_desc_t *primitive_desc,
+        const primitive_attr_t **attr) {
+    if (utils::any_null(primitive_desc, attr))
+        return invalid_arguments;
+
+    *attr = primitive_desc->attr();
+    return success;
+}

+ 174 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_desc.hpp

@@ -0,0 +1,174 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_DESC_HPP
+#define PRIMITIVE_DESC_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "memory_tracking.hpp"
+#include "nstl.hpp"
+#include "type_helpers.hpp"
+#include "primitive_attr.hpp"
+#include "verbose.hpp"
+
+struct mkldnn_primitive_desc: public mkldnn::impl::c_compatible {
+    using md_t = mkldnn::impl::memory_desc_t;
+
+    mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::primitive_attr_t *attr,
+            mkldnn::impl::primitive_kind_t kind)
+        : engine_(engine), attr_(*attr), kind_(kind) { info_[0] = '\0'; }
+
+    mkldnn_primitive_desc(mkldnn::impl::engine_t *engine,
+            mkldnn::impl::primitive_kind_t kind)
+        : engine_(engine), kind_(kind) { info_[0] = '\0'; }
+
+    virtual mkldnn_primitive_desc *clone() const = 0;
+    virtual ~mkldnn_primitive_desc() {}
+
+    const mkldnn::impl::primitive_attr_t *attr() const { return &attr_; }
+    mkldnn::impl::engine_t *engine() const { return engine_; }
+    mkldnn::impl::primitive_kind_t kind() const { return kind_; }
+
+    virtual void init_info() {}
+    const char *info() const { return info_; }
+
+    mkldnn::impl::memory_tracking::registry_t &scratchpad_registry()
+    { return scratchpad_registry_; }
+    const mkldnn::impl::memory_tracking::registry_t &scratchpad_registry() const
+    { return scratchpad_registry_; }
+    virtual mkldnn::impl::engine_t *scratchpad_engine() const
+    { return engine_; }
+
+    virtual const mkldnn::impl::op_desc_t *op_desc() const { return nullptr; }
+
+    enum class arg_usage_t { unused, input, output };
+    virtual arg_usage_t arg_usage(
+            mkldnn::impl::primitive_arg_index_t arg) const {
+        using mkldnn::impl::types::is_zero_md;
+        if (arg == MKLDNN_ARG_SCRATCHPAD && !is_zero_md(scratchpad_md()))
+            return arg_usage_t::output;
+        return arg_usage_t::unused;
+    }
+
+#   define DECLARE_MD_STUB(stub) \
+    virtual const mkldnn::impl::memory_desc_t *stub(int idx = 0) const \
+    { return nullptr; }
+
+    DECLARE_MD_STUB(input_md); DECLARE_MD_STUB(output_md);
+    DECLARE_MD_STUB(src_md); DECLARE_MD_STUB(diff_src_md);
+    DECLARE_MD_STUB(dst_md); DECLARE_MD_STUB(diff_dst_md);
+    DECLARE_MD_STUB(weights_md); DECLARE_MD_STUB(diff_weights_md);
+    DECLARE_MD_STUB(workspace_md);
+#   undef DECLARE_MD_STUB
+
+    const mkldnn::impl::memory_desc_t *scratchpad_md(int idx = 0) const {
+        return idx == 0 ? &scratchpad_md_ : nullptr;
+    }
+
+    virtual void init_scratchpad_md() {
+        auto size = scratchpad_size(mkldnn::impl::scratchpad_mode::user);
+        mkldnn::impl::dims_t dims = { size };
+        mkldnn_memory_desc_init_by_tag(&scratchpad_md_, size ? 1 : 0, dims,
+                mkldnn::impl::data_type::u8, mkldnn_x);
+    }
+
+    /** returns the scratchpad size for the given scratchpad mode. */
+    mkldnn::impl::dim_t scratchpad_size(
+            mkldnn::impl::scratchpad_mode_t mode) const {
+        if (mode != attr_.scratchpad_mode_) return 0;
+        return scratchpad_registry().size();
+    }
+
+    virtual int n_inputs() const { return 0; }
+    virtual int n_outputs() const { return 0; }
+
+    virtual mkldnn::impl::status_t query(mkldnn::impl::query_t what, int idx,
+            void *result) const;
+
+    virtual mkldnn::impl::status_t create_primitive(
+            mkldnn::impl::primitive_t **primitive) const = 0;
+
+    virtual const char *name() const { return "mkldnn_primitive_desc"; }
+
+    /* static magic */
+
+    template<typename pd_t>
+    static mkldnn::impl::status_t create(mkldnn::impl::primitive_desc_t **pd,
+            const mkldnn::impl::op_desc_t *adesc,
+            const mkldnn::impl::primitive_attr_t *attr,
+            mkldnn::impl::engine_t *engine,
+            const mkldnn::impl::primitive_desc_t *hint_fwd) {
+        using namespace mkldnn::impl;
+        using namespace mkldnn::impl::status;
+        using pd_op_desc_t = typename pkind_traits<pd_t::base_pkind>::desc_type;
+        if (adesc->kind != pd_t::base_pkind) return invalid_arguments;
+        assert(hint_fwd ? hint_fwd->kind() == pd_t::base_pkind : true);
+        auto hint =
+            reinterpret_cast<const typename pd_t::hint_class *>(hint_fwd);
+        auto _pd = new pd_t(engine, (const pd_op_desc_t *)adesc, attr, hint);
+        if (_pd == nullptr) return out_of_memory;
+        if (_pd->init() != success) { delete _pd; return unimplemented; }
+        _pd->init_info();
+        _pd->init_scratchpad_md();
+        *pd = _pd;
+        return success;
+    }
+
+protected:
+    mkldnn::impl::engine_t *engine_;
+    mkldnn::impl::primitive_attr_t attr_;
+    mkldnn::impl::primitive_kind_t kind_;
+
+    mkldnn::impl::memory_desc_t scratchpad_md_;
+
+    char info_[MKLDNN_VERBOSE_BUF_LEN];
+
+    mkldnn::impl::memory_tracking::registry_t scratchpad_registry_;
+
+protected:
+    /** compares the workspace md between fwd_pd and this (makes sense to use
+     * for bwd_pd). Expectation: this pd has already set its workspace, and
+     * that workspace should exactly match the one from fwd_pd */
+    bool compare_ws(const mkldnn_primitive_desc *fwd_pd) const {
+        using namespace mkldnn::impl;
+        if (!workspace_md()) return true; // the impl lives fine w/o workspace
+        return fwd_pd && fwd_pd->workspace_md()
+            && *fwd_pd->workspace_md() == *workspace_md();
+    }
+};
+
+#define DECLARE_COMMON_PD_t(impl_name, ...) \
+    virtual pd_t *clone() const override { return new pd_t(*this); } \
+    virtual status_t create_primitive(primitive_t **p) const override { \
+        double ms = get_msec(); \
+        auto ret = safe_ptr_assign<primitive_t>(*p, new (__VA_ARGS__)(this)); \
+        ms = get_msec() - ms; \
+        if (mkldnn_verbose()->level >= 2) { \
+            printf("mkldnn_verbose,create,%s,%g\n", this->info(), ms); \
+            fflush(0); \
+        } \
+        return ret; \
+    } \
+    virtual const char *name() const override { return impl_name; }
+#define DECLARE_COMMON_PD_T(impl_name, ...) \
+    DECLARE_COMMON_PD_t(impl_name, __VA_ARGS__)
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
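
For orientation, a hedged sketch of how an implementation's descriptor typically classifies its runtime arguments on top of the base class above; the type name example_fwd_pd_t is hypothetical (real descriptors derive through intermediate helper pds), and the class stays abstract since clone() and create_primitive() are not overridden:

    #include "primitive_desc.hpp"

    // Illustrative only: classify MKLDNN_ARG_* indices for a forward primitive.
    struct example_fwd_pd_t : public mkldnn_primitive_desc {
        using mkldnn_primitive_desc::mkldnn_primitive_desc;

        virtual arg_usage_t arg_usage(
                mkldnn::impl::primitive_arg_index_t arg) const override {
            if (arg == MKLDNN_ARG_SRC || arg == MKLDNN_ARG_WEIGHTS)
                return arg_usage_t::input;
            if (arg == MKLDNN_ARG_DST)
                return arg_usage_t::output;
            // The base class reports the scratchpad as an extra output when a
            // non-empty scratchpad_md() is present.
            return mkldnn_primitive_desc::arg_usage(arg);
        }
    };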

+ 90 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.cpp

@@ -0,0 +1,90 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "memory.hpp"
+#include "primitive.hpp"
+#include "primitive_exec_types.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+        const mkldnn_exec_arg_t *c_args, exec_args_t &args) {
+    using namespace status;
+
+    if (!IMPLICATION(nargs > 0, c_args != nullptr)) return invalid_arguments;
+
+    int n_inputs = 0;
+    int n_outputs = 0;
+
+    for (int i = 0; i < nargs; ++i) {
+        primitive_arg_index_t arg = c_args[i].arg;
+        auto *mem = c_args[i].memory;
+
+        switch (pd->arg_usage(arg)) {
+        case primitive_desc_t::arg_usage_t::input:
+            if (args.count(arg) != 0) return invalid_arguments;
+            args[arg] = {mem, true};
+            n_inputs++;
+            break;
+        case primitive_desc_t::arg_usage_t::output:
+            if (args.count(arg) != 0) return invalid_arguments;
+            args[arg] = {mem, false};
+            n_outputs++;
+            break;
+        case primitive_desc_t::arg_usage_t::unused:
+            break;
+        }
+    }
+
+    bool scratchpad_required = !types::is_zero_md(pd->scratchpad_md());
+
+    if (n_inputs != pd->n_inputs()) return invalid_arguments;
+    if (n_outputs != pd->n_outputs() + (scratchpad_required ? 1 : 0))
+        return invalid_arguments;
+
+    return success;
+}
+
+const void *exec_ctx_t::input(primitive_arg_index_t arg) const {
+    if (args_.count(arg) != 1) return nullptr;
+    const auto ma = args_.at(arg);
+    assert(ma.is_const);
+    void *ptr;
+    status_t status = ma.mem->get_data_handle(&ptr);
+    assert(status == status::success); MAYBE_UNUSED(status);
+    return ptr;
+}
+
+void *exec_ctx_t::output(primitive_arg_index_t arg) const {
+    if (args_.count(arg) != 1) return nullptr;
+    const auto ma = args_.at(arg);
+    assert(!ma.is_const);
+    void *ptr;
+    status_t status = ma.mem->get_data_handle(&ptr);
+    assert(status == status::success); MAYBE_UNUSED(status);
+    return ptr;
+}
+
+const memory_t *exec_ctx_t::memory(primitive_arg_index_t arg) const {
+    assert(args_.count(arg) == 1);
+    const auto ma = args_.at(arg);
+    assert(!ma.is_const);
+    return ma.mem;
+}
+
+}
+}
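
The user-visible side that cvt_primtive_args() serves looks roughly like this; a hedged sketch assuming the mkldnn_primitive_execute() entry point declared in mkldnn.h, with prim, strm, src_mem and dst_mem created elsewhere:

    #include "mkldnn.h"

    // Each runtime argument is an {argument index, memory handle} pair.
    mkldnn_status_t run_example(mkldnn_primitive_t prim, mkldnn_stream_t strm,
            mkldnn_memory_t src_mem, mkldnn_memory_t dst_mem) {
        mkldnn_exec_arg_t args[] = {
            { MKLDNN_ARG_SRC, src_mem }, // matched against arg_usage() == input
            { MKLDNN_ARG_DST, dst_mem }, // matched against arg_usage() == output
        };
        return mkldnn_primitive_execute(prim, strm, 2, args);
    }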

+ 68 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_exec_types.hpp

@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef PRIMITIVE_EXEC_TYPES_HPP
+#define PRIMITIVE_EXEC_TYPES_HPP
+
+#include <unordered_map>
+
+#include "mkldnn_types.h"
+
+#include "c_types_map.hpp"
+#include "memory.hpp"
+#include "primitive_desc.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct memory_arg_t {
+    memory_t *mem;
+    bool is_const;
+};
+
+using exec_args_t = std::unordered_map<primitive_arg_index_t, memory_arg_t>;
+
+status_t cvt_primtive_args(const primitive_desc_t *pd, int nargs,
+        const mkldnn_exec_arg_t *c_args, exec_args_t &args);
+
+/** Primitive execution context (helps pass the stream, memories, and events). */
+struct exec_ctx_t {
+    exec_ctx_t(const exec_ctx_t &) = default;
+    exec_ctx_t(exec_ctx_t &&) = default;
+
+    exec_ctx_t(stream_t *stream): stream_(stream) {}
+    exec_ctx_t(stream_t *stream, exec_args_t &&args)
+        : stream_(stream)
+        , args_(std::move(args)) {}
+
+    stream_t *stream() const { return stream_; }
+    const exec_args_t &args() const { return args_; }
+
+    /* tentative solution... TODO: replace with functions that return memory_t */
+    const void *input(primitive_arg_index_t arg) const;
+    void *output(primitive_arg_index_t arg) const;
+
+    const memory_t *memory(primitive_arg_index_t arg) const;
+
+private:
+    stream_t *stream_;
+    exec_args_t args_;
+};
+
+}
+}
+
+#endif

+ 89 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.cpp

@@ -0,0 +1,89 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+#include "primitive_iterator.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_iterator_create(
+        primitive_desc_iterator_t **iterator, const_c_op_desc_t c_op_desc,
+        const primitive_attr_t *attr, engine_t *engine,
+        const primitive_desc_t *hint_fwd_pd) {
+    const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
+    auto it = new primitive_desc_iterator_t(engine, op_desc, attr, hint_fwd_pd);
+    if (it == nullptr) return out_of_memory;
+
+    ++(*it);
+    if (*it == it->end()) {
+        delete it;
+        return unimplemented;
+    }
+
+    *iterator = it;
+    return success;
+}
+
+status_t mkldnn_primitive_desc_iterator_next(
+        primitive_desc_iterator_t *iterator) {
+    if (iterator == nullptr) return invalid_arguments;
+    ++(*iterator);
+    return *iterator == iterator->end() ? iterator_ends : success;
+}
+
+primitive_desc_t *mkldnn_primitive_desc_iterator_fetch(
+        const primitive_desc_iterator_t *iterator) {
+    if (iterator == nullptr) return nullptr;
+    return *(*iterator);
+}
+
+status_t mkldnn_primitive_desc_clone(primitive_desc_t **primitive_desc,
+        const primitive_desc_t *existing_primitive_desc) {
+    if (utils::any_null(primitive_desc, existing_primitive_desc))
+        return invalid_arguments;
+    return safe_ptr_assign<primitive_desc_t>(*primitive_desc,
+            existing_primitive_desc->clone());
+}
+
+status_t mkldnn_primitive_desc_iterator_destroy(
+        primitive_desc_iterator_t *iterator) {
+    if (iterator != nullptr)
+        delete iterator;
+    return success;
+}
+
+status_t mkldnn_primitive_desc_create(primitive_desc_t **primitive_desc,
+        const_c_op_desc_t c_op_desc, const primitive_attr_t *attr,
+        engine_t *engine, const primitive_desc_t *hint_fwd_pd) {
+    const op_desc_t *op_desc = (const op_desc_t *)c_op_desc;
+
+    mkldnn_primitive_desc_iterator it(engine, op_desc, attr, hint_fwd_pd);
+    ++it;
+    if (it == it.end()) return unimplemented;
+
+    return safe_ptr_assign<primitive_desc_t>(*primitive_desc, *it);
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
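
A hedged usage sketch of the iterator C API defined above (op_desc, attr, engine and hint are assumed to exist already; mkldnn_query_impl_info_str and mkldnn_primitive_desc_destroy come from mkldnn.h):

    #include <stdio.h>
    #include "mkldnn.h"

    // Walk every implementation that accepts the given operation descriptor.
    void list_impls_example(const_mkldnn_op_desc_t op_desc,
            const_mkldnn_primitive_attr_t attr, mkldnn_engine_t engine,
            const_mkldnn_primitive_desc_t hint) {
        mkldnn_primitive_desc_iterator_t it;
        if (mkldnn_primitive_desc_iterator_create(&it, op_desc, attr, engine,
                    hint) != mkldnn_success)
            return; // 'unimplemented' if nothing accepted the descriptor

        do {
            mkldnn_primitive_desc_t pd = mkldnn_primitive_desc_iterator_fetch(it);
            const char *impl_name = NULL;
            mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0,
                    &impl_name);
            printf("candidate implementation: %s\n", impl_name);
            mkldnn_primitive_desc_destroy(pd); // fetch() returns a clone
        } while (mkldnn_primitive_desc_iterator_next(it) == mkldnn_success);

        mkldnn_primitive_desc_iterator_destroy(it);
    }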

+ 79 - 0
thirdparty/oidn/mkl-dnn/src/common/primitive_iterator.hpp

@@ -0,0 +1,79 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+#ifndef PRIMITIVE_ITERATOR_HPP
+#define PRIMITIVE_ITERATOR_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+struct mkldnn_primitive_desc_iterator: public mkldnn::impl::c_compatible {
+    using pd_create_f = mkldnn::impl::engine_t::primitive_desc_create_f;
+
+    mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, const mkldnn::impl::op_desc_t *op_desc,
+            const mkldnn::impl::primitive_attr_t *attr, const mkldnn::impl::primitive_desc_t *hint_fwd_pd)
+        : idx_(-1), engine_(engine), pd_(nullptr), op_desc_(op_desc)
+        , attr_(attr ? *attr : mkldnn::impl::primitive_attr_t()), hint_fwd_pd_(hint_fwd_pd)
+        , impl_list_(engine_->get_implementation_list()), last_idx_(0)
+    {
+        while (impl_list_[last_idx_] != nullptr) ++last_idx_;
+    }
+    ~mkldnn_primitive_desc_iterator() { if (pd_) delete pd_; }
+
+    bool operator==(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+    { return idx_ == rhs.idx_ && engine_ == rhs.engine_; }
+    bool operator!=(const mkldnn::impl::primitive_desc_iterator_t& rhs) const
+    { return !operator==(rhs); }
+
+    mkldnn::impl::primitive_desc_iterator_t end() const
+    { return mkldnn_primitive_desc_iterator(engine_, last_idx_); }
+
+    mkldnn::impl::primitive_desc_iterator_t &operator++() {
+        if (pd_) { delete pd_; pd_ = nullptr; }
+        while (++idx_ != last_idx_) {
+            auto s = impl_list_[idx_](&pd_, op_desc_, &attr_, engine_,
+                    hint_fwd_pd_);
+            if (s == mkldnn::impl::status::success) break;
+        }
+        return *this;
+    }
+
+    mkldnn::impl::primitive_desc_t *operator*() const {
+        if (*this == end() || pd_ == nullptr) return nullptr;
+        return pd_->clone();
+    }
+
+protected:
+    int idx_;
+    mkldnn::impl::engine_t *engine_;
+    mkldnn::impl::primitive_desc_t *pd_;
+    const mkldnn::impl::op_desc_t *op_desc_;
+    const mkldnn::impl::primitive_attr_t attr_;
+    const mkldnn::impl::primitive_desc_t *hint_fwd_pd_;
+    const pd_create_f *impl_list_;
+    int last_idx_;
+
+private:
+    mkldnn_primitive_desc_iterator(mkldnn::impl::engine_t *engine, int last_idx)
+        : idx_(last_idx), engine_(engine), pd_(nullptr)
+        , op_desc_(nullptr), hint_fwd_pd_(nullptr)
+        , impl_list_(nullptr), last_idx_(last_idx) {}
+};
+
+#endif

+ 59 - 0
thirdparty/oidn/mkl-dnn/src/common/query.cpp

@@ -0,0 +1,59 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "primitive_desc.hpp"
+#include "utils.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_primitive_desc_query(const primitive_desc_t *primitive_desc,
+        query_t what, int index, void *result) {
+    if (any_null(primitive_desc, result))
+        return invalid_arguments;
+
+    return primitive_desc->query(what, index, result);
+}
+
+const memory_desc_t *mkldnn_primitive_desc_query_md(
+        const primitive_desc_t *primitive_desc, query_t what, int index) {
+    const memory_desc_t *res_md = nullptr;
+    bool args_ok = true
+        && primitive_desc != nullptr
+        && (what & query::some_md) == query::some_md
+        && what != query::some_md
+        && mkldnn_primitive_desc_query(primitive_desc,
+                what, index, &res_md) == success;
+    return args_ok ? res_md : nullptr;
+}
+
+int mkldnn_primitive_desc_query_s32(const primitive_desc_t *primitive_desc,
+        query_t what, int index) {
+    int res_s32;
+    bool args_ok = primitive_desc != nullptr
+        && one_of(what, query::num_of_inputs_s32, query::num_of_outputs_s32)
+        && mkldnn_primitive_desc_query(primitive_desc, what, index, &res_s32)
+                == success;
+    return args_ok ? res_s32 : 0;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
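
A brief sketch of the convenience queries above (pd is assumed to be a valid primitive descriptor handle; the mkldnn_query_* enum spellings come from mkldnn_types.h):

    #include "mkldnn.h"

    // Typed helpers return NULL / 0 on any mismatch instead of a status code;
    // the generic form returns a status and writes through the result pointer.
    void query_example(const_mkldnn_primitive_desc_t pd) {
        const mkldnn_memory_desc_t *src_md =
                mkldnn_primitive_desc_query_md(pd, mkldnn_query_src_md, 0);
        int n_inputs = mkldnn_primitive_desc_query_s32(pd,
                mkldnn_query_num_of_inputs_s32, 0);
        const char *impl_name = NULL;
        mkldnn_primitive_desc_query(pd, mkldnn_query_impl_info_str, 0, &impl_name);
        (void)src_md; (void)n_inputs; (void)impl_name;
    }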

+ 68 - 0
thirdparty/oidn/mkl-dnn/src/common/reorder.cpp

@@ -0,0 +1,68 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <assert.h>
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "engine.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+#include "reorder_pd.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::utils;
+using namespace mkldnn::impl::status;
+
+status_t mkldnn_reorder_primitive_desc_create(
+        primitive_desc_t **reorder_pd,
+        engine_t *src_engine, const memory_desc_t *src_md,
+        engine_t *dst_engine, const memory_desc_t *dst_md,
+        const primitive_attr_t *attr) {
+    if (any_null(reorder_pd, src_engine, src_md, dst_engine, dst_md))
+        return invalid_arguments;
+
+    auto s_ek = src_engine->kind();
+    auto d_ek = dst_engine->kind();
+    if (!IMPLICATION(s_ek != d_ek, one_of(engine_kind::cpu, s_ek, d_ek)))
+        return invalid_arguments;
+
+    auto r_pd = reinterpret_cast<reorder_pd_t **>(reorder_pd);
+    auto s_mdw = memory_desc_wrapper(*src_md);
+    auto d_mdw = memory_desc_wrapper(*dst_md);
+
+    if (!s_mdw.consistent_with(d_mdw))
+        return invalid_arguments;
+
+    auto e = (s_ek != engine_kind::cpu) ? src_engine : dst_engine;
+
+    const primitive_attr_t dummy_attr;
+    if (attr == NULL)
+        attr = &dummy_attr;
+
+    for (auto r = e->get_reorder_implementation_list(); *r; ++r) {
+        if ((*r)(r_pd, e, attr, src_engine, src_md, dst_engine, dst_md)
+                == success) {
+            (*r_pd)->init_info();
+            (*r_pd)->init_scratchpad_md();
+            return success;
+        }
+    }
+    return unimplemented;
+}
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s
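
A minimal sketch of calling the entry point above, with both memory descriptors living on the same CPU engine and a NULL attribute selecting the defaults (engine and descriptor creation are assumed to happen elsewhere):

    #include "mkldnn.h"

    mkldnn_status_t make_reorder_pd_example(mkldnn_engine_t cpu_engine,
            const mkldnn_memory_desc_t *src_md, const mkldnn_memory_desc_t *dst_md,
            mkldnn_primitive_desc_t *reorder_pd) {
        return mkldnn_reorder_primitive_desc_create(reorder_pd,
                cpu_engine, src_md, // source engine + layout
                cpu_engine, dst_md, // destination engine + layout
                NULL);              // NULL attr -> default (dummy) attributes
    }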

+ 85 - 0
thirdparty/oidn/mkl-dnn/src/common/reorder_pd.hpp

@@ -0,0 +1,85 @@
+/*******************************************************************************
+* Copyright 2016-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef REORDER_PD_HPP
+#define REORDER_PD_HPP
+
+#include <assert.h>
+
+#include "c_types_map.hpp"
+#include "primitive_attr.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct reorder_pd_t: public primitive_desc_t {
+    reorder_pd_t(engine_t *engine, const primitive_attr_t *attr,
+            engine_t *src_engine, const memory_desc_t *src_md,
+            engine_t *dst_engine, const memory_desc_t *dst_md)
+        : primitive_desc_t(engine, attr, primitive_kind::reorder)
+        , src_engine_(src_engine)
+        , dst_engine_(dst_engine)
+        , scratchpad_engine_(nullptr)
+        , src_md_(*src_md)
+        , dst_md_(*dst_md)
+    {}
+
+    virtual const op_desc_t *op_desc() const override { return nullptr; }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_FROM)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_TO)
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override
+    { return index == 0 ? &src_md_ : nullptr; }
+    virtual const memory_desc_t *dst_md(int index = 0) const override
+    { return index == 0 ? &dst_md_ : nullptr; }
+
+    virtual int n_inputs() const override { return 1; }
+    virtual int n_outputs() const override { return 1; }
+
+    float alpha() const { return attr()->output_scales_.scales_[0]; }
+    float beta() const {
+        const int sum_idx = attr()->post_ops_.find(primitive_kind::sum);
+        return sum_idx == -1 ? 0 : attr()->post_ops_.entry_[sum_idx].sum.scale;
+    }
+    virtual mkldnn::impl::engine_t *scratchpad_engine() const override
+    { return scratchpad_engine_; }
+
+protected:
+    engine_t *src_engine_;
+    engine_t *dst_engine_;
+    engine_t *scratchpad_engine_;
+
+    memory_desc_t src_md_;
+    memory_desc_t dst_md_;
+};
+
+}
+}
+
+#endif
+
+// vim: et ts=4 sw=4 cindent cino^=l0,\:0,N-s

+ 400 - 0
thirdparty/oidn/mkl-dnn/src/common/rnn.cpp

@@ -0,0 +1,400 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "type_helpers.hpp"
+#include "utils.hpp"
+#include "cpu/gemm/os_blas.hpp"
+
+using namespace mkldnn::impl;
+using namespace mkldnn::impl::status;
+using namespace mkldnn::impl::types;
+using namespace mkldnn::impl::utils;
+
+namespace {
+memory_desc_t copy_maybe_null(const memory_desc_t *md) {
+    return md ? *md : zero_md();
+}
+
+rnn_desc_t zero_rnn_desc() {
+    auto rd = rnn_desc_t();
+    rd.src_layer_desc = zero_md();
+    rd.src_iter_desc = zero_md();
+    rd.weights_layer_desc = zero_md();
+    rd.weights_iter_desc = zero_md();
+    rd.bias_desc = zero_md();
+    rd.dst_layer_desc = zero_md();
+    rd.dst_iter_desc = zero_md();
+    rd.diff_src_layer_desc = zero_md();
+    rd.diff_src_iter_desc = zero_md();
+    rd.diff_weights_layer_desc = zero_md();
+    rd.diff_weights_iter_desc = zero_md();
+    rd.diff_bias_desc = zero_md();
+    rd.diff_dst_layer_desc = zero_md();
+    rd.diff_dst_iter_desc = zero_md();
+    return rd;
+}
+}
+
+/* Public C Api */
+
+status_t mkldnn_rnn_cell_desc_init(rnn_cell_desc_t *rnn_cell_desc,
+        mkldnn_alg_kind_t cell_kind, mkldnn_alg_kind_t act_f,
+        unsigned int flags, float alpha, float clipping) {
+    using namespace mkldnn::impl::alg_kind;
+
+    bool args_ok = true
+            && one_of(cell_kind, vanilla_rnn, vanilla_lstm, vanilla_gru,
+                    gru_linear_before_reset)
+            && IMPLICATION(cell_kind == vanilla_rnn,
+                    one_of(act_f, eltwise_relu, eltwise_tanh, eltwise_logistic));
+    if (!args_ok)
+        return invalid_arguments;
+
+    auto rcd = mkldnn_rnn_cell_desc_t();
+
+    rcd.cell_kind = cell_kind;
+    rcd.activation_kind = act_f;
+    rcd.flags = flags;
+    rcd.alpha = rcd.flags & mkldnn_rnn_cell_with_relu ? alpha : 0;
+    rcd.clipping = rcd.flags & mkldnn_rnn_cell_with_clipping ? clipping : 0;
+
+    *rnn_cell_desc = rcd;
+
+    return success;
+}
+
+int mkldnn_rnn_cell_get_gates_count(const rnn_cell_desc_t *rnn_cell_desc) {
+    switch (rnn_cell_desc->cell_kind) {
+    case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+    case mkldnn::impl::alg_kind::vanilla_gru: return 3;
+    case mkldnn::impl::alg_kind::gru_linear_before_reset: return 3;
+    case mkldnn::impl::alg_kind::vanilla_lstm: return 4;
+    default: assert(!"unknown cell kind"); return 0;
+    }
+    return 0;
+}
+
+int mkldnn_rnn_cell_get_states_count(const rnn_cell_desc_t *rnn_cell_desc) {
+    switch (rnn_cell_desc->cell_kind) {
+    case mkldnn::impl::alg_kind::vanilla_rnn: return 1;
+    case mkldnn::impl::alg_kind::vanilla_gru: return 1;
+    case mkldnn::impl::alg_kind::gru_linear_before_reset: return 1;
+    case mkldnn::impl::alg_kind::vanilla_lstm: return 2;
+    default: assert(!"unknown cell kind"); return 0;
+    }
+    return 0;
+}
+
+status_t check_data_type_consistency_fwd(const rnn_cell_desc_t *rnn_cell_desc,
+        prop_kind_t prop_kind, const memory_desc_t *src_layer_desc,
+        const memory_desc_t *src_iter_desc,
+        const memory_desc_t *weights_layer_desc,
+        const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_layer_desc,
+        const memory_desc_t *dst_iter_desc) {
+    using namespace data_type;
+    data_type_t src_layer_dt = src_layer_desc->data_type;
+    data_type_t dst_layer_dt = dst_layer_desc->data_type;
+    data_type_t weights_iter_dt = weights_iter_desc->data_type;
+    data_type_t weights_layer_dt = weights_layer_desc->data_type;
+
+    bool is_f32 = everyone_is(f32, src_layer_dt, dst_layer_dt, weights_iter_dt,
+                          weights_layer_dt)
+            && IMPLICATION(!is_zero_md(src_iter_desc),
+                          src_iter_desc->data_type == f32)
+            && IMPLICATION(!is_zero_md(dst_iter_desc),
+                          dst_iter_desc->data_type == f32)
+            && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+#if USE_MKL_PACKED_GEMM
+    bool is_u8u8u8 = src_layer_dt == u8
+            && IMPLICATION(!is_zero_md(src_iter_desc),
+                             src_iter_desc->data_type == u8)
+            && IMPLICATION(!is_zero_md(dst_iter_desc),
+                             dst_iter_desc->data_type == u8)
+            && one_of(dst_layer_dt, u8, f32)
+            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+            && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+    bool is_f32u8f32 = src_layer_dt == u8
+            && IMPLICATION(!is_zero_md(src_iter_desc),
+                               src_iter_desc->data_type == f32)
+            && IMPLICATION(!is_zero_md(dst_iter_desc),
+                               dst_iter_desc->data_type == f32)
+            && one_of(dst_layer_dt, u8, f32)
+            && everyone_is(s8, weights_iter_dt, weights_layer_dt)
+            && IMPLICATION(!is_zero_md(bias_desc), bias_desc->data_type == f32);
+
+    bool is_inference = prop_kind == prop_kind::forward_inference;
+    bool is_lstm = rnn_cell_desc->cell_kind == mkldnn_vanilla_lstm;
+
+    return (is_f32 || ((is_u8u8u8 || is_f32u8f32) && is_lstm && is_inference))
+            ? success
+            : unimplemented;
+#else
+    return is_f32 ? success : unimplemented;
+#endif
+}
+
+status_t check_dim_consistency(const rnn_cell_desc_t *rnn_cell_desc,
+        rnn_direction_t direction, int L, int D, int T, int N, int S, int G,
+        int SLC, int SIC, int DLC, int DIC, const memory_desc_t *src_layer_desc,
+        const memory_desc_t *src_iter_desc,
+        const memory_desc_t *weights_layer_desc,
+        const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_layer_desc,
+        const memory_desc_t *dst_iter_desc) {
+    bool args_ok;
+
+    // * algorithm specific
+    args_ok = true
+        && IMPLICATION(rnn_cell_desc->cell_kind == alg_kind::vanilla_gru,
+                       DIC == SIC);
+    if (!args_ok) return invalid_arguments;
+    int extra_bias =
+            rnn_cell_desc->cell_kind == alg_kind::gru_linear_before_reset;
+
+    // * on num layers
+    args_ok = true
+        && L == weights_layer_desc->dims[0]
+        && L == weights_iter_desc->dims[0]
+        && IMPLICATION(!is_zero_md(bias_desc), L == bias_desc->dims[0])
+        && IMPLICATION(!is_zero_md(src_iter_desc), L == src_iter_desc->dims[0])
+        && IMPLICATION(!is_zero_md(dst_iter_desc), L == dst_iter_desc->dims[0]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on num directions
+    args_ok = true
+        && D == weights_layer_desc->dims[1]
+        && D == weights_iter_desc->dims[1]
+        && IMPLICATION(!is_zero_md(bias_desc), D == bias_desc->dims[1])
+        && IMPLICATION(!is_zero_md(src_iter_desc), D == src_iter_desc->dims[1])
+        && IMPLICATION(!is_zero_md(dst_iter_desc), D == dst_iter_desc->dims[1]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on num iterations
+    args_ok = true
+        && T == src_layer_desc->dims[0]
+        && T == dst_layer_desc->dims[0];
+    if (!args_ok) return invalid_arguments;
+
+    // * on mb
+    args_ok = true
+        && N == src_layer_desc->dims[1]
+        && N == dst_layer_desc->dims[1]
+        && IMPLICATION(!is_zero_md(src_iter_desc), N == src_iter_desc->dims[3])
+        && IMPLICATION(!is_zero_md(dst_iter_desc), N == dst_iter_desc->dims[3]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on num gates
+    args_ok = true
+        && G == mkldnn_rnn_cell_get_gates_count(rnn_cell_desc)
+        && G == weights_layer_desc->dims[3]
+        && G == weights_iter_desc->dims[3]
+        && IMPLICATION(!is_zero_md(bias_desc),
+                G + extra_bias == bias_desc->dims[2]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on num states
+    args_ok = true
+        && S == mkldnn_rnn_cell_get_states_count(rnn_cell_desc)
+        && IMPLICATION(!is_zero_md(src_iter_desc), S == src_iter_desc->dims[2])
+        && IMPLICATION(!is_zero_md(dst_iter_desc), S == dst_iter_desc->dims[2]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on slc
+    args_ok = true
+        && SLC == weights_layer_desc->dims[2]
+        && SLC == src_layer_desc->dims[2];
+    if (!args_ok) return invalid_arguments;
+
+    // * on sic
+    args_ok = true
+        && SIC == weights_iter_desc->dims[2]
+        && IMPLICATION(!is_zero_md(src_iter_desc),
+                SIC == src_iter_desc->dims[4]);
+    if (!args_ok) return invalid_arguments;
+
+    // * on dlc
+    int dlc_multiplier = (direction == mkldnn_bidirectional_concat) ? 2 : 1;
+    args_ok = true
+        && DLC == dlc_multiplier * DIC
+        && DLC == dst_layer_desc->dims[2];
+    if (!args_ok) return invalid_arguments;
+
+    // * on dic
+    args_ok = true
+        && DIC == weights_layer_desc->dims[4]
+        && DIC == weights_iter_desc->dims[4]
+        && IMPLICATION(!is_zero_md(bias_desc), DIC == bias_desc->dims[3])
+        && IMPLICATION(!is_zero_md(dst_iter_desc),
+                DIC == dst_iter_desc->dims[4]);
+    if (!args_ok) return invalid_arguments;
+
+    // * unrolling/fusion conditions
+    args_ok = true
+        && IMPLICATION(L > 1, (dlc_multiplier * SLC) == DLC)
+        && IMPLICATION(T > 1, SIC == DIC);
+    if (!args_ok) return invalid_arguments;
+
+    return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_forward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+        prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+        const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+        const memory_desc_t *src_iter_desc,
+        const memory_desc_t *weights_layer_desc,
+        const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_layer_desc,
+        const memory_desc_t *dst_iter_desc) {
+    bool args_ok = true && rnn_cell_desc != nullptr
+            && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+                       dst_layer_desc);
+    if (!args_ok) return invalid_arguments;
+
+    // check dimension consistency
+    int L = weights_layer_desc->dims[0];
+    int T = src_layer_desc->dims[0];
+    int N = src_layer_desc->dims[1];
+    const int D = one_of(direction, mkldnn_unidirectional_left2right,
+                          mkldnn_unidirectional_right2left) ?
+            1 :
+            2;
+    int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+    int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+    int SLC = src_layer_desc->dims[2];
+    int SIC = weights_iter_desc->dims[2];
+    int DLC = dst_layer_desc->dims[2];
+    int DIC = weights_layer_desc->dims[4];
+
+    CHECK(check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+            G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+            weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+            dst_iter_desc));
+
+    CHECK(check_data_type_consistency_fwd(rnn_cell_desc, prop_kind,
+            src_layer_desc, src_iter_desc, weights_layer_desc,
+            weights_iter_desc, bias_desc, dst_layer_desc, dst_iter_desc));
+
+    // Create the descriptor
+    mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+    rd.primitive_kind = primitive_kind::rnn;
+    rd.prop_kind = prop_kind;
+    rd.cell_desc = *rnn_cell_desc;
+    rd.direction = direction;
+    rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+    rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+    rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+    rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+    rd.bias_desc = copy_maybe_null(bias_desc);
+    rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+    rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+
+    *rnn_desc = rd;
+
+    return success;
+}
+
+status_t MKLDNN_API mkldnn_rnn_backward_desc_init(mkldnn_rnn_desc_t *rnn_desc,
+        prop_kind_t prop_kind, const rnn_cell_desc_t *rnn_cell_desc,
+        const rnn_direction_t direction, const memory_desc_t *src_layer_desc,
+        const memory_desc_t *src_iter_desc,
+        const memory_desc_t *weights_layer_desc,
+        const memory_desc_t *weights_iter_desc, const memory_desc_t *bias_desc,
+        const memory_desc_t *dst_layer_desc, const memory_desc_t *dst_iter_desc,
+        const memory_desc_t *diff_src_layer_desc,
+        const memory_desc_t *diff_src_iter_desc,
+        const memory_desc_t *diff_weights_layer_desc,
+        const memory_desc_t *diff_weights_iter_desc,
+        const memory_desc_t *diff_bias_desc,
+        const memory_desc_t *diff_dst_layer_desc,
+        const memory_desc_t *diff_dst_iter_desc) {
+    bool args_ok = true
+            && !any_null(src_layer_desc, weights_layer_desc, weights_iter_desc,
+                       dst_layer_desc, diff_src_layer_desc,
+                       diff_weights_layer_desc, diff_weights_iter_desc,
+                       diff_dst_layer_desc);
+    if (!args_ok)
+        return invalid_arguments;
+
+    auto xnor_md = [=](const memory_desc_t *a_md, const memory_desc_t *b_md) {
+        return is_zero_md(a_md) == is_zero_md(b_md);
+    };
+
+    args_ok = args_ok && xnor_md(bias_desc, diff_bias_desc)
+            && xnor_md(dst_iter_desc, diff_dst_iter_desc)
+            && xnor_md(src_iter_desc, diff_src_iter_desc);
+    if (!args_ok)
+        return invalid_arguments;
+
+    // check dimension consistency
+    int L = weights_layer_desc->dims[0];
+    int T = src_layer_desc->dims[0];
+    int N = src_layer_desc->dims[1];
+    const int D = one_of(direction, mkldnn_unidirectional_left2right,
+                          mkldnn_unidirectional_right2left) ?
+            1 :
+            2;
+    int G = mkldnn_rnn_cell_get_gates_count(rnn_cell_desc);
+    int S = mkldnn_rnn_cell_get_states_count(rnn_cell_desc);
+    int SLC = src_layer_desc->dims[2];
+    int SIC = weights_iter_desc->dims[2];
+    int DLC = dst_layer_desc->dims[2];
+    int DIC = weights_layer_desc->dims[4];
+
+    status_t st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+            G, SLC, SIC, DLC, DIC, src_layer_desc, src_iter_desc,
+            weights_layer_desc, weights_iter_desc, bias_desc, dst_layer_desc,
+            dst_iter_desc);
+    if (st != success) return st;
+
+    st = check_dim_consistency(rnn_cell_desc, direction, L, D, T, N, S,
+            G, SLC, SIC, DLC, DIC, diff_src_layer_desc, diff_src_iter_desc,
+            diff_weights_layer_desc, diff_weights_iter_desc, diff_bias_desc,
+            diff_dst_layer_desc, diff_dst_iter_desc);
+    if (st != success) return st;
+
+    mkldnn_rnn_desc_t rd = zero_rnn_desc();
+
+    rd.primitive_kind = primitive_kind::rnn;
+    rd.prop_kind = prop_kind;
+    rd.cell_desc = *rnn_cell_desc;
+    rd.direction = direction;
+
+    rd.src_layer_desc = copy_maybe_null(src_layer_desc);
+    rd.src_iter_desc = copy_maybe_null(src_iter_desc);
+    rd.weights_layer_desc = copy_maybe_null(weights_layer_desc);
+    rd.weights_iter_desc = copy_maybe_null(weights_iter_desc);
+    rd.bias_desc = copy_maybe_null(bias_desc);
+    rd.dst_layer_desc = copy_maybe_null(dst_layer_desc);
+    rd.dst_iter_desc = copy_maybe_null(dst_iter_desc);
+    rd.diff_src_layer_desc = copy_maybe_null(diff_src_layer_desc);
+    rd.diff_src_iter_desc = copy_maybe_null(diff_src_iter_desc);
+    rd.diff_weights_layer_desc = copy_maybe_null(diff_weights_layer_desc);
+    rd.diff_weights_iter_desc = copy_maybe_null(diff_weights_iter_desc);
+    rd.diff_bias_desc = copy_maybe_null(diff_bias_desc);
+    rd.diff_dst_layer_desc = copy_maybe_null(diff_dst_layer_desc);
+    rd.diff_dst_iter_desc = copy_maybe_null(diff_dst_iter_desc);
+
+    *rnn_desc = rd;
+
+    return success;
+}
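
Both initializers above follow the same pattern: validate the argument set, check dimension and data-type consistency, then copy each descriptor into a zeroed mkldnn_rnn_desc_t. As a minimal sketch of how an application would drive the forward variant through the public C API (assuming thirdparty/oidn/mkl-dnn/include/mkldnn.h is on the include path; the wrapper name init_vanilla_rnn_fwd is chosen purely for illustration):

    #include "mkldnn.h"

    // Illustrative wrapper only, not library code: builds a forward RNN
    // descriptor from already-initialized cell and memory descriptors and
    // reports success as a bool.
    bool init_vanilla_rnn_fwd(mkldnn_rnn_desc_t *rd,
            const mkldnn_rnn_cell_desc_t *cell,
            const mkldnn_memory_desc_t *src_layer,
            const mkldnn_memory_desc_t *src_iter,      // optional, may be a zero md
            const mkldnn_memory_desc_t *weights_layer,
            const mkldnn_memory_desc_t *weights_iter,
            const mkldnn_memory_desc_t *bias,          // optional
            const mkldnn_memory_desc_t *dst_layer,
            const mkldnn_memory_desc_t *dst_iter) {    // optional
        mkldnn_status_t st = mkldnn_rnn_forward_desc_init(rd,
                mkldnn_forward_inference, cell,
                mkldnn_unidirectional_left2right,
                src_layer, src_iter, weights_layer, weights_iter,
                bias, dst_layer, dst_iter);
        return st == mkldnn_success;
    }

Optional descriptors (src_iter, bias, dst_iter and their diff counterparts) may be passed as zero memory descriptors; copy_maybe_null stores them as such, and the with_*() queries in rnn_pd.hpp below then treat them as absent. The backward initializer additionally enforces, via xnor_md, that each optional descriptor and its diff counterpart are either both present or both absent.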

+ 280 - 0
thirdparty/oidn/mkl-dnn/src/common/rnn_pd.hpp

@@ -0,0 +1,280 @@
+/*******************************************************************************
+* Copyright 2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef RNN_PD_HPP
+#define RNN_PD_HPP
+
+#include "mkldnn.h"
+
+#include "c_types_map.hpp"
+#include "primitive_desc.hpp"
+#include "type_helpers.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+struct rnn_fwd_pd_t;
+
+struct rnn_pd_t : public primitive_desc_t {
+    static constexpr auto base_pkind = primitive_kind::rnn;
+
+    rnn_pd_t(engine_t *engine,
+            const rnn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const rnn_fwd_pd_t *hint_fwd_pd)
+        : primitive_desc_t(engine, attr, base_pkind)
+        , desc_(*adesc)
+        , hint_fwd_pd_(hint_fwd_pd)
+        , src_layer_md_(desc_.src_layer_desc)
+        , src_iter_md_(desc_.src_iter_desc)
+        , weights_layer_md_(desc_.weights_layer_desc)
+        , weights_iter_md_(desc_.weights_iter_desc)
+        , bias_md_(desc_.bias_desc)
+        , dst_layer_md_(desc_.dst_layer_desc)
+        , dst_iter_md_(desc_.dst_iter_desc)
+        , ws_md_()
+    {}
+
+    const rnn_desc_t *desc() const { return &desc_; }
+    virtual const op_desc_t *op_desc() const override
+    { return reinterpret_cast<const op_desc_t *>(this->desc()); }
+    virtual void init_info() override { impl::init_info(this, this->info_); }
+
+    virtual status_t query(query_t what, int idx, void *result) const override {
+        switch (what) {
+        case query::rnn_d: *(const rnn_desc_t **)result = desc(); break;
+        default: return primitive_desc_t::query(what, idx, result);
+        }
+        return status::success;
+    }
+
+    virtual const memory_desc_t *src_md(int index = 0) const override {
+        if (index == 0) return &src_layer_md_;
+        if (index == 1 && with_src_iter()) return &src_iter_md_;
+        return nullptr;
+    }
+    virtual const memory_desc_t *weights_md(int index = 0) const override {
+        if (index == 0) return &weights_layer_md_;
+        if (index == 1) return &weights_iter_md_;
+        if (index == 2 && with_bias()) return &bias_md_;
+        return nullptr;
+    }
+    virtual const memory_desc_t *dst_md(int index = 0) const override {
+        if (index == 0) return &dst_layer_md_;
+        if (index == 1 && with_dst_iter()) return &dst_iter_md_;
+        return nullptr;
+    }
+    virtual const memory_desc_t *workspace_md(int index = 0) const override
+    { return index == 0 && !types::is_zero_md(&ws_md_) ? &ws_md_ : nullptr; }
+
+    /* common rnn aux functions */
+
+    bool is_training() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::backward);
+    }
+
+    bool is_fwd() const {
+        return utils::one_of(desc_.prop_kind, prop_kind::forward_training,
+                prop_kind::forward_inference);
+    }
+
+    dim_t T() const { return desc_.src_layer_desc.dims[0]; }
+    dim_t MB() const { return desc_.src_layer_desc.dims[1]; }
+
+    dim_t L() const { return desc_.weights_layer_desc.dims[0]; }
+    dim_t D() const { return desc_.weights_layer_desc.dims[1]; }
+
+    dim_t SIC() const { return desc_.weights_iter_desc.dims[2]; }
+
+    dim_t SLC() const { return desc_.weights_layer_desc.dims[2]; }
+    dim_t G() const { return desc_.weights_layer_desc.dims[3]; }
+    dim_t DIC() const { return desc_.weights_layer_desc.dims[4]; }
+
+    dim_t DLC() const { return desc_.dst_layer_desc.dims[2]; }
+
+    bool with_bias() const
+    { return !memory_desc_wrapper(desc_.bias_desc).is_zero(); }
+
+    bool with_src_iter() const
+    { return !(memory_desc_wrapper(desc_.src_iter_desc).is_zero()); }
+
+    bool with_dst_iter() const
+    { return !memory_desc_wrapper(desc_.dst_iter_desc).is_zero(); }
+
+    mkldnn::impl::alg_kind_t cell_kind() const
+    { return desc_.cell_desc.cell_kind; }
+    mkldnn::impl::alg_kind_t activation_kind() const
+    { return desc_.cell_desc.activation_kind; }
+
+    bool is_lbr() const
+    { return cell_kind() == mkldnn_gru_linear_before_reset; }
+
+    mkldnn_rnn_direction_t direction() const { return desc_.direction; }
+
+protected:
+    rnn_desc_t desc_;
+    const rnn_fwd_pd_t *hint_fwd_pd_;
+
+    memory_desc_t src_layer_md_;
+    memory_desc_t src_iter_md_;
+    memory_desc_t weights_layer_md_;
+    memory_desc_t weights_iter_md_;
+    memory_desc_t bias_md_;
+    memory_desc_t dst_layer_md_;
+    memory_desc_t dst_iter_md_;
+
+    memory_desc_t ws_md_;
+};
+
+struct rnn_fwd_pd_t: public rnn_pd_t {
+    typedef rnn_fwd_pd_t base_class;
+    typedef rnn_fwd_pd_t hint_class;
+
+    rnn_fwd_pd_t(engine_t *engine,
+            const rnn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const rnn_fwd_pd_t *hint_fwd_pd)
+        : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (arg == MKLDNN_ARG_SRC_LAYER)
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_SRC_ITER && with_src_iter())
+            return arg_usage_t::input;
+
+        if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+                    MKLDNN_ARG_WEIGHTS_ITER))
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_BIAS && with_bias())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_DST_LAYER)
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_DST_ITER && with_dst_iter())
+            return arg_usage_t::output;
+
+        if (arg == MKLDNN_ARG_WORKSPACE && is_training())
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual int n_inputs() const override
+    { return 3 + with_bias() + with_src_iter(); }
+    virtual int n_outputs() const override
+    { return 1 + with_dst_iter() + is_training(); }
+};
+
+struct rnn_bwd_pd_t : public rnn_pd_t {
+    typedef rnn_bwd_pd_t base_class;
+    typedef rnn_fwd_pd_t hint_class;
+
+    rnn_bwd_pd_t(engine_t *engine,
+            const rnn_desc_t *adesc,
+            const primitive_attr_t *attr,
+            const rnn_fwd_pd_t *hint_fwd_pd)
+        : rnn_pd_t(engine, adesc, attr, hint_fwd_pd)
+        , diff_src_layer_md_(desc_.diff_src_layer_desc)
+        , diff_src_iter_md_(desc_.diff_src_iter_desc)
+        , diff_weights_layer_md_(desc_.diff_weights_layer_desc)
+        , diff_weights_iter_md_(desc_.diff_weights_iter_desc)
+        , diff_bias_md_(desc_.diff_bias_desc)
+        , diff_dst_layer_md_(desc_.diff_dst_layer_desc)
+        , diff_dst_iter_md_(desc_.diff_dst_iter_desc)
+    {}
+
+    virtual arg_usage_t arg_usage(primitive_arg_index_t arg) const override {
+        if (utils::one_of(arg, MKLDNN_ARG_SRC_LAYER, MKLDNN_ARG_DST_LAYER,
+                    MKLDNN_ARG_DIFF_DST_LAYER))
+            return arg_usage_t::input;
+
+        if (with_src_iter()) {
+            if (arg == MKLDNN_ARG_SRC_ITER)
+                return arg_usage_t::input;
+
+            if (arg == MKLDNN_ARG_DIFF_SRC_ITER)
+                return arg_usage_t::output;
+        }
+
+        if (utils::one_of(arg, MKLDNN_ARG_WEIGHTS_LAYER,
+                    MKLDNN_ARG_WEIGHTS_ITER))
+            return arg_usage_t::input;
+
+        if (with_bias()) {
+            if (arg == MKLDNN_ARG_BIAS)
+                return arg_usage_t::input;
+
+            if (arg == MKLDNN_ARG_DIFF_BIAS)
+                return arg_usage_t::output;
+        }
+
+        if (utils::one_of(arg, MKLDNN_ARG_DST_ITER, MKLDNN_ARG_DIFF_DST_ITER)
+                && with_dst_iter())
+            return arg_usage_t::input;
+
+        if (arg == MKLDNN_ARG_WORKSPACE)
+            return arg_usage_t::input;
+
+        if (utils::one_of(arg, MKLDNN_ARG_DIFF_SRC_LAYER,
+                    MKLDNN_ARG_DIFF_WEIGHTS_LAYER,
+                    MKLDNN_ARG_DIFF_WEIGHTS_ITER))
+            return arg_usage_t::output;
+
+        return primitive_desc_t::arg_usage(arg);
+    }
+
+    virtual const memory_desc_t *diff_src_md(int index = 0) const override {
+        if (index == 0) return &diff_src_layer_md_;
+        if (index == 1 && with_src_iter()) return &diff_src_iter_md_;
+        return nullptr;
+    }
+    virtual const memory_desc_t *diff_weights_md(
+            int index = 0) const override {
+        if (index == 0) return &diff_weights_layer_md_;
+        if (index == 1) return &diff_weights_iter_md_;
+        if (index == 2 && with_bias()) return &diff_bias_md_;
+        return nullptr;
+    }
+    virtual const memory_desc_t *diff_dst_md(int index = 0) const override {
+        if (index == 0) return &diff_dst_layer_md_;
+        if (index == 1 && with_dst_iter()) return &diff_dst_iter_md_;
+        return nullptr;
+    }
+
+    virtual int n_inputs() const override
+    { return 6 + with_src_iter() + with_bias() + 2 * with_dst_iter(); }
+    virtual int n_outputs() const override
+    { return 3 + with_src_iter() + with_bias(); }
+
+protected:
+    memory_desc_t diff_src_layer_md_;
+    memory_desc_t diff_src_iter_md_;
+    memory_desc_t diff_weights_layer_md_;
+    memory_desc_t diff_weights_iter_md_;
+    memory_desc_t diff_bias_md_;
+    memory_desc_t diff_dst_layer_md_;
+    memory_desc_t diff_dst_iter_md_;
+};
+
+}
+}
+
+#endif
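
The single-letter accessors above (T, MB, L, D, G, SLC, SIC, DLC, DIC) all read fixed slots of the underlying memory-descriptor dims arrays, matching the dimension extraction done in rnn.cpp. The standalone sketch below restates that mapping outside the class hierarchy; it is illustrative only, and the weights_iter shape is completed by analogy with weights_layer rather than taken verbatim from this header.

    // Illustration of the RNN dimension convention used by rnn_pd_t:
    //   src_layer     dims = {T, MB, SLC}
    //   weights_layer dims = {L, D, SLC, G, DIC}
    //   weights_iter  dims = {L, D, SIC, G, DIC}   (shape assumed by analogy)
    //   dst_layer     dims = {T, MB, DLC}
    struct rnn_dims_view {
        const int *src_layer_dims;      // 3 entries
        const int *weights_layer_dims;  // 5 entries
        const int *weights_iter_dims;   // 5 entries
        const int *dst_layer_dims;      // 3 entries

        int T() const   { return src_layer_dims[0]; }      // time steps
        int MB() const  { return src_layer_dims[1]; }      // minibatch size
        int L() const   { return weights_layer_dims[0]; }  // stacked layers
        int D() const   { return weights_layer_dims[1]; }  // directions (1 or 2)
        int SLC() const { return weights_layer_dims[2]; }  // src layer channels
        int G() const   { return weights_layer_dims[3]; }  // gates per cell
        int DIC() const { return weights_layer_dims[4]; }  // dst iter channels
        int SIC() const { return weights_iter_dims[2]; }   // src iter channels
        int DLC() const { return dst_layer_dims[2]; }      // dst layer channels
    };

In rnn.cpp the same quantities are extracted locally (with N in place of MB) before check_dim_consistency is called.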

+ 112 - 0
thirdparty/oidn/mkl-dnn/src/common/scratchpad.cpp

@@ -0,0 +1,112 @@
+/*******************************************************************************
+* Copyright 2017-2018 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "mkldnn_thread.hpp"
+#include "utils.hpp"
+
+#include "scratchpad.hpp"
+
+namespace mkldnn {
+namespace impl {
+
+/* Allocating memory buffers on a page boundary to reduce TLB/page misses */
+const size_t page_size = 2097152;
+
+/*
+  Implementation of the scratchpad_t interface that is compatible with
+  concurrent execution
+*/
+struct concurrent_scratchpad_t : public scratchpad_t {
+    concurrent_scratchpad_t(size_t size) {
+        size_ = size;
+        scratchpad_ = (char *) malloc(size, page_size);
+        assert(scratchpad_ != nullptr);
+    }
+
+    ~concurrent_scratchpad_t() {
+        free(scratchpad_);
+    }
+
+    virtual char *get() const {
+        return scratchpad_;
+    }
+
+private:
+    char *scratchpad_;
+    size_t size_;
+};
+
+/*
+  Implementation of the scratchpad_t interface that uses a global
+  scratchpad
+*/
+
+struct global_scratchpad_t : public scratchpad_t {
+    global_scratchpad_t(size_t size) {
+        if (size > size_) {
+            if (scratchpad_ != nullptr) free(scratchpad_);
+            size_ = size;
+            scratchpad_ = (char *) malloc(size, page_size);
+            assert(scratchpad_ != nullptr);
+        }
+        reference_count_++;
+    }
+
+    ~global_scratchpad_t() {
+        reference_count_--;
+        if (reference_count_ == 0) {
+            free(scratchpad_);
+            scratchpad_ = nullptr;
+            size_ = 0;
+        }
+    }
+
+    virtual char *get() const {
+        return scratchpad_;
+    }
+
+private:
+    /*
+      Using thread-local here is unnecessary and even buggy! All threads
+      actually share the same scratchpad, which is created and queried only
+      on the main thread. If the scratchpad is queried on some thread other
+      than the one it was created on (e.g. the application calls the API from
+      multiple threads), thread-local causes a segfault because the scratchpad
+      is uninitialized on the current thread.
+    */
+    /*thread_local*/ static char *scratchpad_;
+    /*thread_local*/ static size_t size_;
+    /*thread_local*/ static unsigned int reference_count_;
+};
+
+/*thread_local*/ char *global_scratchpad_t::scratchpad_ = nullptr;
+/*thread_local*/ size_t global_scratchpad_t::size_ = 0;
+/*thread_local*/ unsigned int global_scratchpad_t::reference_count_ = 0;
+
+
+/*
+   Scratchpad creation routine
+*/
+scratchpad_t *create_scratchpad(size_t size) {
+#ifndef MKLDNN_ENABLE_CONCURRENT_EXEC
+    return new global_scratchpad_t(size);
+#else
+    return new concurrent_scratchpad_t(size);
+#endif
+}
+
+}
+}

Some files were not shown because too many files changed in this diff