connection.js 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. const { lookup } = require('lookup.js')
  2. const { createClient } = require('tcp.js')
  3. const { md5AuthMessage, syncMessage, startupMessage, createParser, getPGError, constants } = require('pg.js')
  4. const { html } = just.library('html.so', 'html')
  5. const {
  6. AuthenticationOk,
  7. ErrorResponse,
  8. RowDescription,
  9. CommandComplete,
  10. ParseComplete,
  11. NoData,
  12. ReadyForQuery
  13. } = constants.messageTypes
  14. const { INT4OID } = constants.fieldTypes
  /**
   * Describe a protocol message type code.
   *
   * @param {number} type - message type byte
   * @returns {{type: number, code: string, name: string}} the byte, its character,
   *   and the first key of constants.messageTypes whose value equals it ('' if none)
   */
  function getMessageName (type) {
    const { messageTypes } = constants
    const match = Object.keys(messageTypes).find(key => messageTypes[key] === type)
    return {
      type,
      code: String.fromCharCode(type),
      name: match === undefined ? '' : match
    }
  }
  /**
   * Attach PostgreSQL wire-protocol handling to a connected TCP socket.
   *
   * Installs a parser on `sock` (as `sock.parser`) and adds `start`,
   * `authenticate`, `sync`, `compile`, `getParams`, `size`, `query` and
   * `buffer`. Requests and replies are paired through a FIFO queue of
   * callbacks: each request pushes one callback and `onMessage` shifts the
   * oldest one when certain server messages arrive.
   *
   * @param {object} sock - client socket from `createClient`
   * @param {object} config - connection settings; passed to `startupMessage`,
   *   and its `user` / `pass` are used by `authenticate`
   * @returns {object} the same `sock`, augmented in place
   */
  function setupSocket (sock, config) {
    /**
     * Pre-encode the extended-query messages for `query` (Parse, Describe,
     * Bind, Execute, Sync, Flush, laid out in that order) into one buffer and
     * attach helpers that send them.
     *
     * Every property of the internal state object (`call`, `append`, `send`,
     * `bind`, `exec`, `prepare`, `describe`, `messages`, `buf`, `size`,
     * `paramStart`, ...) is copied onto `query`, together with `getRows` and
     * `getResult`; `query` itself is returned. When `onComplete` is given,
     * Parse and then Describe are sent, and `onComplete(err)` runs once the
     * replies to both have been dispatched through the callback queue (or on
     * the first error).
     *
     * @param {object} query - { name, sql, params, formats, fields, portal,
     *   maxRows, htmlEscape }. `formats[i]` supplies both the parameter type
     *   OID (Parse) and the parameter format code (Bind); `fields[i].format`
     *   is the requested result column format.
     * @param {function} [onComplete]
     * @returns {object} query
     */
    function compile (query, onComplete) {
      // NOTE(review): fixed 4 KiB encode buffer; long SQL or many parameters may not fit.
      const buf = new ArrayBuffer(4096)
      const dv = new DataView(buf)
      const fun = {
        dv,
        size: 0,
        described: false,
        // staging buffer for batches queued with append() and written by send()
        buffer: new ArrayBuffer(65536),
        messages: {
          prepare: { start: 0, len: 0 },
          bind: { start: 0, len: 0 },
          exec: { start: 0, len: 0 },
          describe: { start: 0, len: 0 },
          flush: { start: 0, len: 0 },
          sync: { start: 0, len: 0 }
        },
        paramStart: 0
      }
      fun.buffer.offset = 0
      const { name, sql, params = [], formats = [], fields = [], portal = '', maxRows = 0 } = query
      const { prepare, describe, bind, exec, sync, flush } = fun.messages

      // Copy the current contents of `params` into the encoded Bind message.
      // 32 bit integers only for now: every parameter is assumed to have been
      // encoded in binary as an Int32 length (4) followed by the Int32 value.
      function refreshParams () {
        let pos = fun.paramStart
        for (let i = 0; i < params.length; i++) {
          dv.setUint32(pos + 4, params[i])
          pos += 8
        }
      }

      // Length of the byte run that starts at Bind and ends after Execute,
      // Sync (syncIt) or Flush (flushIt). Because the layout is Bind, Execute,
      // Sync, Flush, a flushed run also includes the Sync message.
      function runLength (syncIt, flushIt) {
        let last = exec
        if (flushIt) {
          last = flush
        } else if (syncIt) {
          last = sync
        }
        return last.start + last.len - bind.start
      }

      // Write one encoded message, optionally followed by Flush, and queue onComplete.
      function sendMessage (message, flushIt, onComplete) {
        sock.write(buf, message.len, message.start)
        if (flushIt) {
          sock.write(buf, flush.len, flush.start)
        }
        callbacks.push(onComplete)
      }

      // Bind + Execute (+ Sync / Flush) written straight to the socket.
      fun.call = (onComplete, syncIt = true, flushIt = false) => {
        refreshParams()
        const len = runLength(syncIt, flushIt)
        const r = sock.write(buf, len, bind.start)
        if (r < len) {
          just.error('short write')
        }
        callbacks.push(onComplete)
      }
      // Same bytes as call(), but copied into fun.buffer to be written later by send().
      fun.append = (onComplete, syncIt = true, flushIt = false) => {
        refreshParams()
        const len = runLength(syncIt, flushIt)
        fun.buffer.offset += fun.buffer.copyFrom(buf, fun.buffer.offset, len, bind.start)
        callbacks.push(onComplete)
      }
      // Write everything queued by append() and reset the staging buffer.
      // A write counts as short when fewer than the queued bytes were written.
      fun.send = () => {
        const pending = fun.buffer.offset
        const r = sock.write(fun.buffer, pending, 0)
        if (r < pending) {
          just.error('short write')
        }
        fun.buffer.offset = 0
      }
      fun.bind = (flushIt = true, onComplete) => {
        sendMessage(bind, flushIt, onComplete)
      }
      fun.exec = (flushIt = true, onComplete) => {
        sendMessage(exec, flushIt, onComplete)
      }
      fun.prepare = (flushIt = true, onComplete) => {
        sendMessage(prepare, flushIt, onComplete)
      }
      fun.describe = (flushIt = true, onComplete) => {
        sendMessage(describe, flushIt, onComplete)
      }

      let off = 0
      // Begin a message at `off`: type byte, then room for its Int32 length.
      const beginMessage = (message, type) => {
        message.start = off
        dv.setUint8(off++, type)
        off += 4
      }
      // End a message: record its span and backfill the Int32 length, which
      // counts itself and the body but not the type byte. Deriving it from the
      // bytes actually written keeps it right for strings whose encoded size
      // differs from their .length.
      const endMessage = message => {
        message.len = off - message.start
        dv.setUint32(message.start + 1, message.len - 1)
      }
      // Write `str` followed by a NUL terminator.
      const writeCString = str => {
        if (str.length) {
          off += buf.writeString(str, off)
        }
        dv.setUint8(off++, 0)
      }

      // Parse: statement name, SQL text, parameter type OIDs
      beginMessage(prepare, 80) // 'P'
      writeCString(name)
      writeCString(sql)
      dv.setUint16(off, formats.length)
      off += 2
      for (const { oid } of formats) {
        dv.setUint32(off, oid)
        off += 4
      }
      endMessage(prepare)

      // Describe the prepared statement ('S') by name
      beginMessage(describe, 68) // 'D'
      dv.setUint8(off++, 83) // 'S'
      writeCString(name)
      endMessage(describe)

      // Bind: portal, statement, parameter format codes, parameter values,
      // result column format codes
      beginMessage(bind, 66) // 'B'
      writeCString(portal)
      writeCString(name)
      dv.setUint16(off, formats.length)
      off += 2
      for (const { format } of formats) {
        dv.setUint16(off, format)
        off += 2
      }
      dv.setUint16(off, params.length)
      off += 2
      fun.paramStart = off
      // With no format codes every parameter is encoded as text
      const defaultFormat = formats[0] || { format: 0 }
      for (let i = 0; i < params.length; i++) {
        if ((formats[i] || defaultFormat).format === 1) {
          // binary int32: Int32 length 4, then the value
          dv.setUint32(off, 4)
          off += 4
          dv.setUint32(off, params[i])
          off += 4
        } else {
          // text: Int32 byte length (backfilled from the bytes written), then the bytes
          const lengthAt = off
          off += 4
          const written = buf.writeString(params[i].toString(), off)
          dv.setUint32(lengthAt, written)
          off += written
        }
      }
      dv.setUint16(off, fields.length)
      off += 2
      for (const { format } of fields) {
        dv.setUint16(off, format)
        off += 2
      }
      endMessage(bind)

      // Execute: portal name and maxRows
      beginMessage(exec, 69) // 'E'
      writeCString(portal)
      dv.setUint32(off, maxRows)
      off += 4
      endMessage(exec)

      beginMessage(sync, 83) // 'S'
      endMessage(sync)

      beginMessage(flush, 72) // 'H'
      endMessage(flush)

      fun.size = off
      fun.buf = buf.slice(0, off)
      Object.assign(query, fun)

      const readString = query.htmlEscape ? html.escape : just.sys.readString
      // Column descriptor used when query.fields describes no columns: text format
      const textField = { format: 0 }

      /**
       * Decode the DataRow messages located by `parser.query` ({ start, rows })
       * in the parser's buffer into arrays of column values.
       * A column length of -1 (NULL) yields null; INT4 columns become numbers;
       * other text columns are read with just.sys.readString (html.escape when
       * query.htmlEscape is set); other binary columns are returned as slices
       * of the parser buffer.
       */
      query.getRows = () => {
        const { buf, dv } = parser
        const fields = query.fields || []
        const { start, rows } = parser.query
        const result = []
        let pos = start
        for (let i = 0; i < rows; i++) {
          pos += 5 // skip the message type byte and Int32 length
          const cols = dv.getUint16(pos)
          pos += 2
          const row = new Array(cols)
          result.push(row)
          for (let j = 0; j < cols; j++) {
            // signed read: a NULL column is sent as length -1 with no value bytes
            const len = dv.getInt32(pos)
            pos += 4
            if (len === -1) {
              row[j] = null
              continue
            }
            const { oid, format } = fields[j] || fields[0] || textField
            if (format === 0) {
              row[j] = oid === INT4OID ? parseInt(buf.readString(len, pos), 10) : readString(buf, len, pos)
            } else {
              row[j] = oid === INT4OID ? dv.getInt32(pos) : buf.slice(pos, pos + len)
            }
            pos += len
          }
        }
        return result
      }
      query.getResult = () => parser.getResult()
      if (!onComplete) return query
      fun.prepare(true, err => {
        if (err) return onComplete(err)
        fun.describe(true, err => {
          if (err) return onComplete(err)
          onComplete()
        })
      })
      return query
    }

    // Queue onStart and send the startup message built from config.
    function start (onStart) {
      callbacks.push(onStart)
      sock.write(startupMessage(config))
    }

    // Queue onAuthenticate and answer with an MD5 auth message using the parser's salt.
    function authenticate (onAuthenticate) {
      callbacks.push(onAuthenticate)
      sock.write(md5AuthMessage({ user, pass, salt: parser.salt }))
    }

    // Shift the oldest pending callback and invoke it (with `err` when given).
    // Requests made without a callback leave `undefined` in the queue; those
    // entries are consumed silently instead of throwing inside the data handler.
    function settle (err) {
      const callback = callbacks.shift()
      if (typeof callback !== 'function') return
      if (err) {
        callback(err)
      } else {
        callback()
      }
    }

    // Parser hook: settle the oldest pending request for the message types
    // that complete one; other message types are ignored here.
    function onMessage () {
      switch (parser.type) {
        case CommandComplete:
          settle()
          break
        case ReadyForQuery:
          // only the first ReadyForQuery (end of startup) completes a request
          if (!sock.authenticated) {
            sock.authenticated = true
            settle()
          }
          break
        case ErrorResponse:
          settle(new Error(getPGError(parser.errors)))
          break
        case AuthenticationOk:
        case ParseComplete:
        case RowDescription:
        case NoData:
          settle()
          break
      }
    }

    const buf = new ArrayBuffer(64 * 1024)
    sock.authenticated = false
    const parser = sock.parser = createParser(buf)
    const callbacks = []
    const { user, pass } = config
    parser.onMessage = onMessage
    sock.authenticate = authenticate
    sock.sync = () => sock.write(syncMessage())
    sock.start = start
    sock.compile = compile
    sock.onData = bytes => parser.parse(bytes)
    sock.onClose = () => {
      just.error('pg socket closed')
    }
    sock.getParams = () => parser.parameters
    sock.size = () => callbacks.length
    sock.query = parser.query
    sock.buffer = buf
    return sock
  }
  /**
   * Resolve `config.hostname`, open a TCP client to it and hand back a socket
   * prepared by `setupSocket`.
   *
   * @param {object} config - needs `hostname` and `port`; `config.address` is
   *   set to the resolved IP as a side effect
   * @param {function} onPGConnect - called as (err) when the lookup fails,
   *   otherwise as (err, sock) when the client's onConnect fires
   */
  function connect (config, onPGConnect) {
    const openSocket = ip => {
      config.address = ip
      const client = createClient(config.address, config.port)
      client.onClose = () => {
        just.error('pg socket closed')
      }
      // onConnect's return value is the sock.buffer assigned by setupSocket
      client.onConnect = err => {
        const pgSocket = setupSocket(client, config)
        onPGConnect(err, pgSocket)
        return client.buffer
      }
      client.connect()
    }
    lookup(config.hostname, (err, ip) => {
      if (err) {
        onPGConnect(err)
        return
      }
      openSocket(ip)
    })
  }
  349. module.exports = { connect, constants, getMessageName }