|
|
@@ -1598,6 +1598,7 @@ private:
|
|
|
int64_t GetMemLimit() const final { return m_iRtMemLimit; }
|
|
|
|
|
|
bool LoadEmbeddingModels ( CSphString & sError );
|
|
|
+ const CSphQuery * SetupAutoEmbeddings ( const CSphQuery & tQuery, CSphQuery & tUpdatedQuery, const ISphSchema & tMatchSchema, CSphString & sError ) const;
|
|
|
bool VerifyKNN ( InsertDocData_c & tDoc, CSphString & sError ) const;
|
|
|
|
|
|
template<typename PRED>
|
|
|
@@ -8069,6 +8070,63 @@ ConstRtData FilterReaderChunks ( ConstRtData tOrigin, const VecTraits_T<int64_t>
|
|
|
return { pConstChunks, pConstSegments };
|
|
|
}
|
|
|
|
|
|
+
|
|
|
+const CSphQuery * RtIndex_c::SetupAutoEmbeddings ( const CSphQuery & tQuery, CSphQuery & tUpdatedQuery, const ISphSchema & tMatchSchema, CSphString & sError ) const
|
|
|
+{
|
|
|
+ auto & tKNN = tQuery.m_tKnnSettings;
|
|
|
+ if ( !m_pEmbeddings || tKNN.m_sAttr.IsEmpty() || tKNN.m_sEmbStr.IsEmpty() )
|
|
|
+ return &tQuery;
|
|
|
+
|
|
|
+ auto pAttr = m_tSchema.GetAttr ( tKNN.m_sAttr.cstr() );
|
|
|
+ if ( !pAttr )
|
|
|
+ {
|
|
|
+ sError.SetSprintf ( "KNN search attribute '%s' not found", tKNN.m_sAttr.cstr() );
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ knn::TextToEmbeddings_i * pModel = m_pEmbeddings->GetModel ( tKNN.m_sAttr );
|
|
|
+ if ( !pModel )
|
|
|
+ {
|
|
|
+ sError.SetSprintf ( "No model loaded for auto embeddings attribute '%s'", tKNN.m_sAttr.cstr() );
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ tUpdatedQuery = tQuery;
|
|
|
+
|
|
|
+ std::vector<std::vector<float>> dEmbeddings;
|
|
|
+ std::vector<std::string_view> dTexts;
|
|
|
+ dTexts.push_back( tKNN.m_sEmbStr.cstr() );
|
|
|
+
|
|
|
+ std::string sConvertError;
|
|
|
+ if ( !pModel->Convert ( dTexts, dEmbeddings, sConvertError ) )
|
|
|
+ {
|
|
|
+ sError.SetSprintf ( "Error generating embeddings for attribute '%s' : %s", tKNN.m_sAttr.cstr(), sConvertError.c_str() );
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ if ( dEmbeddings.size()!=1 )
|
|
|
+ {
|
|
|
+ sError.SetSprintf ( "Error generating embeddings for attribute '%s'", tKNN.m_sAttr.cstr() );
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ int iEmbDim = dEmbeddings[0].size();
|
|
|
+ if ( iEmbDim!=pAttr->m_tKNN.m_iDims )
|
|
|
+ {
|
|
|
+ sError.SetSprintf ( "Auto generated embedding dimension mismatch: expected %d, got %d", pAttr->m_tKNN.m_iDims, iEmbDim );
|
|
|
+ return nullptr;
|
|
|
+ }
|
|
|
+
|
|
|
+ tUpdatedQuery.m_tKnnSettings.m_dVec.Resize(iEmbDim);
|
|
|
+ memcpy ( tUpdatedQuery.m_tKnnSettings.m_dVec.Begin(), dEmbeddings[0].data(), iEmbDim*sizeof(float) );
|
|
|
+
|
|
|
+ for ( int i = 0; i < tMatchSchema.GetAttrsCount(); i++ )
|
|
|
+ if ( tMatchSchema.GetAttr(i).m_pExpr )
|
|
|
+ tMatchSchema.GetAttr(i).m_pExpr->Command ( SPH_EXPR_SET_KNN_VEC, &tUpdatedQuery.m_tKnnSettings.m_dVec );
|
|
|
+
|
|
|
+ return &tUpdatedQuery;
|
|
|
+}
|
|
|
+
|
|
|
// FIXME! missing MVA, index_exact_words support
|
|
|
// FIXME? any chance to factor out common backend agnostic code?
|
|
|
// FIXME? do we need to support pExtraFilters?
|
|
|
@@ -8159,22 +8217,40 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
|
|
|
SwitchProfile ( pProfiler, SPH_QSTATE_INIT );
|
|
|
|
|
|
+ // select the sorter with max schema
|
|
|
+ // uses GetAttrsCount to get working facets (was GetRowSize)
|
|
|
+ int iMaxSchemaIndex, iMatchPoolSize;
|
|
|
+ std::tie ( iMaxSchemaIndex, iMatchPoolSize ) = GetMaxSchemaIndexAndMatchCapacity ( dSorters );
|
|
|
+
|
|
|
+ if ( iMaxSchemaIndex==-1 )
|
|
|
+ return false;
|
|
|
+
|
|
|
+ const ISphSchema & tMaxSorterSchema = *( dSorters[iMaxSchemaIndex]->GetSchema ());
|
|
|
+ auto dSorterSchemas = SorterSchemas ( dSorters, iMaxSchemaIndex );
|
|
|
+
|
|
|
+ CSphQuery tUpdatedQuery;
|
|
|
+ const CSphQuery * pQueryToRun = SetupAutoEmbeddings ( tQuery, tUpdatedQuery, tMaxSorterSchema, tMeta.m_sError );
|
|
|
+ if ( !pQueryToRun )
|
|
|
+ return false;
|
|
|
+
|
|
|
+ auto & tQueryToRun = *pQueryToRun;
|
|
|
+
|
|
|
// FIXME! each result will point to its own MVA and string pools
|
|
|
|
|
|
//////////////////////
|
|
|
// search disk chunks
|
|
|
//////////////////////
|
|
|
|
|
|
- tMeta.m_bHasPrediction = tQuery.m_iMaxPredictedMsec>0;
|
|
|
+ tMeta.m_bHasPrediction = tQueryToRun.m_iMaxPredictedMsec>0;
|
|
|
|
|
|
MiniTimer_c dTimerGuard;
|
|
|
- int64_t tmMaxTimer = dTimerGuard.Engage ( tQuery.m_uMaxQueryMsec ); // max_query_time
|
|
|
+ int64_t tmMaxTimer = dTimerGuard.Engage ( tQueryToRun.m_uMaxQueryMsec ); // max_query_time
|
|
|
|
|
|
SorterSchemaTransform_c tSSTransform ( dDiskChunks.GetLength(), tArgs.m_bFinalizeSorters );
|
|
|
|
|
|
if ( !dDiskChunks.IsEmpty() )
|
|
|
{
|
|
|
- if ( !QueryDiskChunks ( tQuery, tMeta, tArgs, tGuard, dSorters, pProfiler, bGotLocalDF, pLocalDocs, iTotalDocs, GetName(), tSSTransform, tmMaxTimer ) )
|
|
|
+ if ( !QueryDiskChunks ( tQueryToRun, tMeta, tArgs, tGuard, dSorters, pProfiler, bGotLocalDF, pLocalDocs, iTotalDocs, GetName(), tSSTransform, tmMaxTimer ) )
|
|
|
return false;
|
|
|
}
|
|
|
|
|
|
@@ -8184,19 +8260,8 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
|
|
|
SwitchProfile ( pProfiler, SPH_QSTATE_INIT );
|
|
|
|
|
|
- // select the sorter with max schema
|
|
|
- // uses GetAttrsCount to get working facets (was GetRowSize)
|
|
|
- int iMaxSchemaIndex, iMatchPoolSize;
|
|
|
- std::tie ( iMaxSchemaIndex, iMatchPoolSize ) = GetMaxSchemaIndexAndMatchCapacity ( dSorters );
|
|
|
-
|
|
|
- if ( iMaxSchemaIndex==-1 )
|
|
|
- return false;
|
|
|
-
|
|
|
- const ISphSchema & tMaxSorterSchema = *( dSorters[iMaxSchemaIndex]->GetSchema ());
|
|
|
- auto dSorterSchemas = SorterSchemas ( dSorters, iMaxSchemaIndex );
|
|
|
-
|
|
|
// setup calculations and result schema
|
|
|
- CSphQueryContext tCtx ( tQuery );
|
|
|
+ CSphQueryContext tCtx ( tQueryToRun );
|
|
|
tCtx.m_pProfile = pProfiler;
|
|
|
tCtx.m_pLocalDocs = pLocalDocs;
|
|
|
tCtx.m_iTotalDocs = iTotalDocs;
|
|
|
@@ -8207,7 +8272,7 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
tTermSetup.SetDict ( pDict );
|
|
|
tTermSetup.m_pIndex = this;
|
|
|
tTermSetup.m_iDynamicRowitems = tMaxSorterSchema.GetDynamicSize();
|
|
|
- tTermSetup.m_iMaxTimer = dTimerGuard.Engage ( tQuery.m_uMaxQueryMsec ); // max_query_time
|
|
|
+ tTermSetup.m_iMaxTimer = dTimerGuard.Engage ( tQueryToRun.m_uMaxQueryMsec ); // max_query_time
|
|
|
tTermSetup.m_pWarning = &tMeta.m_sWarning;
|
|
|
tTermSetup.SetSegment ( -1 );
|
|
|
tTermSetup.m_pCtx = &tCtx;
|
|
|
@@ -8215,17 +8280,17 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
|
|
|
// setup prediction constrain
|
|
|
CSphQueryStats tQueryStats;
|
|
|
- int64_t iNanoBudget = (int64_t)(tQuery.m_iMaxPredictedMsec) * 1000000; // from milliseconds to nanoseconds
|
|
|
+ int64_t iNanoBudget = (int64_t)(tQueryToRun.m_iMaxPredictedMsec) * 1000000; // from milliseconds to nanoseconds
|
|
|
tQueryStats.m_pNanoBudget = &iNanoBudget;
|
|
|
if ( tMeta.m_bHasPrediction )
|
|
|
tTermSetup.m_pStats = &tQueryStats;
|
|
|
|
|
|
// bind weights
|
|
|
- tCtx.BindWeights ( tQuery, m_tSchema, tMeta.m_sWarning );
|
|
|
+ tCtx.BindWeights ( tQueryToRun, m_tSchema, tMeta.m_sWarning );
|
|
|
|
|
|
CSphVector<BYTE> dFiltered;
|
|
|
- const BYTE * sModifiedQuery = (const BYTE *)tQuery.m_sQuery.cstr();
|
|
|
- FieldFilterOptions_t tFFOptions { tQuery.m_eJiebaMode };
|
|
|
+ const BYTE * sModifiedQuery = (const BYTE *)tQueryToRun.m_sQuery.cstr();
|
|
|
+ FieldFilterOptions_t tFFOptions { tQueryToRun.m_eJiebaMode };
|
|
|
|
|
|
if ( m_pFieldFilter && sModifiedQuery && m_pFieldFilter->Clone ( &tFFOptions )->Apply ( sModifiedQuery, dFiltered, true ) )
|
|
|
sModifiedQuery = dFiltered.Begin();
|
|
|
@@ -8248,13 +8313,13 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
if ( !bFullscan )
|
|
|
{
|
|
|
assert ( m_pQueryTokenizer.Ptr() && m_pQueryTokenizerJson.Ptr() );
|
|
|
- if ( !pQueryParser->ParseQuery ( tParsed, (const char *)sModifiedQuery, &tQuery, m_pQueryTokenizer, m_pQueryTokenizerJson, &m_tSchema, pDict, m_tSettings, &m_tMorphFields ) )
|
|
|
+ if ( !pQueryParser->ParseQuery ( tParsed, (const char *)sModifiedQuery, &tQueryToRun, m_pQueryTokenizer, m_pQueryTokenizerJson, &m_tSchema, pDict, m_tSettings, &m_tMorphFields ) )
|
|
|
{
|
|
|
tMeta.m_sError = tParsed.m_sParseError;
|
|
|
iStackNeed = 0;
|
|
|
} else
|
|
|
{
|
|
|
- iStackNeed = PrepareFTSearch ( this, IsStarDict ( m_bKeywordDict ), m_bKeywordDict, m_tMutableSettings.m_iExpandKeywords, m_iExpansionLimit, m_tSettings, tQuery,(cRefCountedRefPtrGeneric_t) tGuard.m_tSegmentsAndChunks.m_pSegs, pDict, tMeta, pProfiler, &tPayloads, tParsed );
|
|
|
+ iStackNeed = PrepareFTSearch ( this, IsStarDict ( m_bKeywordDict ), m_bKeywordDict, m_tMutableSettings.m_iExpandKeywords, m_iExpansionLimit, m_tSettings, tQueryToRun,(cRefCountedRefPtrGeneric_t) tGuard.m_tSegmentsAndChunks.m_pSegs, pDict, tMeta, pProfiler, &tPayloads, tParsed );
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@@ -8280,19 +8345,19 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
|
|
|
|
|
|
CSphVector<CSphFilterSettings> dTransformedFilters; // holds filter settings if they were modified. filters hold pointers to those settings
|
|
|
CSphVector<FilterTreeItem_t> dTransformedFilterTree;
|
|
|
- if ( !SetupFilters ( tQuery, tMaxSorterSchema, m_tSchema, bParsedFullscan, tCtx, dTransformedFilters, dTransformedFilterTree, dSorterSchemas, tMeta ) )
|
|
|
+ if ( !SetupFilters ( tQueryToRun, tMaxSorterSchema, m_tSchema, bParsedFullscan, tCtx, dTransformedFilters, dTransformedFilterTree, dSorterSchemas, tMeta ) )
|
|
|
return false;
|
|
|
|
|
|
bool bResult;
|
|
|
if ( bParsedFullscan )
|
|
|
- bResult = DoFullScanQuery ( tGuard.m_dRamSegs, tMaxSorterSchema, tQuery, tArgs, m_iStride, tmMaxTimer, pProfiler, tCtx, dSorters, tMeta );
|
|
|
+ bResult = DoFullScanQuery ( tGuard.m_dRamSegs, tMaxSorterSchema, tQueryToRun, tArgs, m_iStride, tmMaxTimer, pProfiler, tCtx, dSorters, tMeta );
|
|
|
else
|
|
|
{
|
|
|
CSphMultiQueryArgs tFTArgs ( tArgs.m_iIndexWeight );
|
|
|
tFTArgs.m_bFinalizeSorters = tArgs.m_bFinalizeSorters;
|
|
|
tMeta.m_bBigram = ( m_tSettings.m_eBigramIndex!=SPH_BIGRAM_NONE );
|
|
|
|
|
|
- bResult = DoFullTextSearch ( tGuard.m_dRamSegs, tMaxSorterSchema, tQuery, tFTArgs, iMatchPoolSize, iStackNeed, tTermSetup, pProfiler, tCtx, dSorters, tParsed, tMeta, dSorters.GetLength()==1 ? dSorters[0] : nullptr );
|
|
|
+ bResult = DoFullTextSearch ( tGuard.m_dRamSegs, tMaxSorterSchema, tQueryToRun, tFTArgs, iMatchPoolSize, iStackNeed, tTermSetup, pProfiler, tCtx, dSorters, tParsed, tMeta, dSorters.GetLength()==1 ? dSorters[0] : nullptr );
|
|
|
}
|
|
|
|
|
|
if (!bResult)
|