WebGPUTextureMipmapUtils.js 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. import { GPUTextureViewDimension, GPUIndexFormat, GPUFilterMode, GPUPrimitiveTopology, GPULoadOp, GPUStoreOp } from './WebGPUConstants.js';
  /**
   * Generates the mip chain of a WebGPU texture on the GPU.
   *
   * Each mip level is produced by drawing a full-screen quad into that level
   * while sampling the previous level through a sampler with linear
   * minification filtering. Render pipelines are created lazily and cached
   * per texture format, because the color target format is part of the
   * pipeline descriptor.
   */
  class WebGPUTextureMipmapUtils {

    /**
     * Creates the sampler and the two shader modules shared by all mipmap
     * pipelines.
     *
     * @param {GPUDevice} device - Device used to create every GPU object and
     * whose queue receives the recorded commands.
     */
    constructor( device ) {

      this.device = device;

      // Vertex stage: emits one corner of a clip-space quad per vertex index
      // (0..3) together with its texture coordinate. The clip-space corner
      // (-1, 1) is paired with uv (0, 0) and (1, -1) with uv (1, 1).
      const mipmapVertexSource = `
        struct VarysStruct {
          @builtin( position ) Position: vec4<f32>,
          @location( 0 ) vTex : vec2<f32>
        };

        @vertex
        fn main( @builtin( vertex_index ) vertexIndex : u32 ) -> VarysStruct {

          var Varys : VarysStruct;

          var pos = array< vec2<f32>, 4 >(
            vec2<f32>( -1.0,  1.0 ),
            vec2<f32>(  1.0,  1.0 ),
            vec2<f32>( -1.0, -1.0 ),
            vec2<f32>(  1.0, -1.0 )
          );

          var tex = array< vec2<f32>, 4 >(
            vec2<f32>( 0.0, 0.0 ),
            vec2<f32>( 1.0, 0.0 ),
            vec2<f32>( 0.0, 1.0 ),
            vec2<f32>( 1.0, 1.0 )
          );

          Varys.vTex = tex[ vertexIndex ];
          Varys.Position = vec4<f32>( pos[ vertexIndex ], 0.0, 1.0 );

          return Varys;

        }
      `;

      // Fragment stage: copies the sampled texel of the source level to the
      // render target. Binding 0 is the sampler, binding 1 the source view.
      const mipmapFragmentSource = `
        @group( 0 ) @binding( 0 )
        var imgSampler : sampler;

        @group( 0 ) @binding( 1 )
        var img : texture_2d<f32>;

        @fragment
        fn main( @location( 0 ) vTex : vec2<f32> ) -> @location( 0 ) vec4<f32> {

          return textureSample( img, imgSampler, vTex );

        }
      `;

      // Only the minification filter is set; every other sampler option is
      // left at its default.
      this.sampler = device.createSampler( { minFilter: GPUFilterMode.Linear } );

      // We'll need a new pipeline for every texture format used.
      this.pipelines = {};

      this.mipmapVertexShaderModule = device.createShaderModule( {
        label: 'mipmapVertex',
        code: mipmapVertexSource
      } );

      this.mipmapFragmentShaderModule = device.createShaderModule( {
        label: 'mipmapFragment',
        code: mipmapFragmentSource
      } );

    }

    /**
     * Returns the render pipeline used to write mip levels of the given
     * format, creating and caching it on first use.
     *
     * @param {string} format - Texture format used as the single color target.
     * @returns {GPURenderPipeline} The cached pipeline for `format`.
     */
    getMipmapPipeline( format ) {

      let pipeline = this.pipelines[ format ];

      if ( pipeline === undefined ) {

        pipeline = this.device.createRenderPipeline( {
          vertex: {
            module: this.mipmapVertexShaderModule,
            entryPoint: 'main'
          },
          fragment: {
            module: this.mipmapFragmentShaderModule,
            entryPoint: 'main',
            targets: [ { format } ]
          },
          primitive: {
            topology: GPUPrimitiveTopology.TriangleStrip,
            stripIndexFormat: GPUIndexFormat.Uint32
          },
          layout: 'auto'
        } );

        this.pipelines[ format ] = pipeline;

      }

      return pipeline;

    }

    /**
     * Fills mip levels 1..mipLevelCount-1 of one array layer of `textureGPU`,
     * rendering each level from the level directly above it. All passes are
     * recorded into a single command encoder that is submitted once.
     *
     * Does nothing when the descriptor declares fewer than two mip levels.
     *
     * @param {GPUTexture} textureGPU - Texture whose level 0 already holds the
     * image data.
     * @param {Object} textureGPUDescriptor - Descriptor of `textureGPU`; its
     * `format` and `mipLevelCount` properties are read.
     * @param {number} [baseArrayLayer=0] - Array layer to generate mipmaps for.
     */
    generateMipmaps( textureGPU, textureGPUDescriptor, baseArrayLayer = 0 ) {

      const levelCount = textureGPUDescriptor.mipLevelCount;

      // Nothing to generate: avoid building a pipeline and submitting an
      // empty command buffer. The negated form also covers an undefined count.
      if ( ! ( levelCount > 1 ) ) return;

      const pipeline = this.getMipmapPipeline( textureGPUDescriptor.format );

      const commandEncoder = this.device.createCommandEncoder( {} );
      const bindGroupLayout = pipeline.getBindGroupLayout( 0 ); // @TODO: Consider making this static.

      // Single-level 2D view of the requested layer at the given mip level.
      const createLevelView = ( level ) => textureGPU.createView( {
        baseMipLevel: level,
        mipLevelCount: 1,
        dimension: GPUTextureViewDimension.TwoD,
        baseArrayLayer
      } );

      let srcView = createLevelView( 0 );

      for ( let i = 1; i < levelCount; i ++ ) {

        const dstView = createLevelView( i );

        const passEncoder = commandEncoder.beginRenderPass( {
          colorAttachments: [ {
            view: dstView,
            loadOp: GPULoadOp.Clear,
            storeOp: GPUStoreOp.Store,
            clearValue: [ 0, 0, 0, 0 ]
          } ]
        } );

        const bindGroup = this.device.createBindGroup( {
          layout: bindGroupLayout,
          entries: [ {
            binding: 0,
            resource: this.sampler
          }, {
            binding: 1,
            resource: srcView
          } ]
        } );

        passEncoder.setPipeline( pipeline );
        passEncoder.setBindGroup( 0, bindGroup );
        passEncoder.draw( 4, 1, 0, 0 ); // 4 vertices = one quad as a triangle strip
        passEncoder.end();

        // The level just written becomes the source of the next, smaller one.
        srcView = dstView;

      }

      this.device.queue.submit( [ commandEncoder.finish() ] );

    }

  }
  119. export default WebGPUTextureMipmapUtils;