Browse Source

feat: generate auto embeddings in queries

Ilya Kuznetsov 7 months ago
parent
commit
59c413879f
7 changed files with 134 additions and 39 deletions
  1. 17 4
      src/knnmisc.cpp
  2. 1 1
      src/queuecreator.cpp
  3. 12 7
      src/searchdsql.cpp
  4. 1 0
      src/sphinx.h
  5. 1 0
      src/sphinxexpr.h
  6. 12 2
      src/sphinxql.y
  7. 90 25
      src/sphinxrt.cpp

+ 17 - 4
src/knnmisc.cpp

@@ -119,20 +119,20 @@ private:
 
 	util::Span_T<const  knn::DocDist_t>	m_dData;
 	mutable const knn::DocDist_t *		m_pStart = nullptr;
+
+	void				SetAnchor ( const CSphVector<float> & dAnchor );
 };
 
 
 Expr_KNNDist_c::Expr_KNNDist_c ( const CSphVector<float> & dAnchor, const CSphColumnInfo & tAttr )
-	: m_dAnchor ( dAnchor )
-	, m_tAttr ( tAttr )
+	: m_tAttr ( tAttr ) // m_tAttr is initialized here because SetAnchor() below reads m_tAttr.m_tKNN
 {
 	knn::IndexSettings_t tDistSettings = tAttr.m_tKNN;
 	tDistSettings.m_eQuantization = knn::Quantization_e::NONE; // we operate on non-quantized data
 	CSphString sError; // fixme! report it
 	m_pDistCalc = CreateKNNDistanceCalc ( tDistSettings, sError );
 
-	if ( tAttr.m_tKNN.m_eHNSWSimilarity==knn::HNSWSimilarity_e::COSINE )
-		NormalizeVec(m_dAnchor);
+	SetAnchor(dAnchor); // stores the anchor in m_dAnchor and applies NormalizeVec() to it for COSINE similarity
 }
 
 
@@ -175,6 +175,10 @@ void Expr_KNNDist_c::Command ( ESphExprCommand eCmd, void * pArg )
 		m_pStart = nullptr;
 		break;
 
+	case SPH_EXPR_SET_KNN_VEC:
+		SetAnchor ( *(const CSphVector<float>*)pArg );
+		break;
+
 	default:
 		break;
 	}
@@ -211,6 +215,15 @@ float Expr_KNNDist_c::CalcDist ( const CSphMatch & tMatch ) const
 	return m_pDistCalc->CalcDist ( { dData.Begin(), (size_t)dData.GetLength() }, { m_dAnchor.Begin(), (size_t)m_dAnchor.GetLength() } );
 }
 
+
+void Expr_KNNDist_c::SetAnchor ( const CSphVector<float> & dAnchor ) // replaces the anchor vector that CalcDist() measures distances against
+{
+	m_dAnchor = dAnchor; // own copy of the caller's vector; the caller's data is left untouched
+
+	if ( m_tAttr.m_tKNN.m_eHNSWSimilarity==knn::HNSWSimilarity_e::COSINE ) // only the COSINE similarity setting gets NormalizeVec() applied to the stored anchor
+		NormalizeVec(m_dAnchor);
+}
+
 /////////////////////////////////////////////////////////////////////
 
 class Expr_KNNDistRescore_c : public Expr_KNNDist_c

+ 1 - 1
src/queuecreator.cpp

@@ -1523,7 +1523,7 @@ bool QueueCreator_c::AddKNNDistColumn()
 		return false;
 	}
 
-	if ( pAttr->m_tKNN.m_iDims!=tKNN.m_dVec.GetLength() )
+	if ( tKNN.m_sEmbStr.IsEmpty() && pAttr->m_tKNN.m_iDims!=tKNN.m_dVec.GetLength() )
 	{
 		m_sError.SetSprintf ( "KNN index '%s' requires a vector of %d entries; %d entries specified", tKNN.m_sAttr.cstr(), pAttr->m_tKNN.m_iDims, tKNN.m_dVec.GetLength() );
 		return false;

+ 12 - 7
src/searchdsql.cpp

@@ -379,7 +379,7 @@ public:
 	bool			AddSchemaItem ( SqlNode_t * pNode );
 	bool			SetMatch ( const SqlNode_t & tValue );
 	bool			AddMatch ( const SqlNode_t & tValue, const SqlNode_t & tIndex );
-	bool			SetKNN ( const SqlNode_t & tAttr, const SqlNode_t & tK, const SqlNode_t & tValues, const CSphVector<CSphNamedVariant> * pOpts );
+	bool			SetKNN ( const SqlNode_t & tAttr, const SqlNode_t & tK, const SqlNode_t & tValues, const CSphVector<CSphNamedVariant> * pOpts, bool bAutoEmb );
 	void			AddConst ( int iList, const SqlNode_t& tValue );
 	void			SetLocalStatement ( const SqlNode_t & tName );
 	bool			AddFloatRangeFilter ( const SqlNode_t & tAttr, float fMin, float fMax, bool bHasEqual, bool bExclude=false );
@@ -1417,7 +1417,7 @@ static bool ParseKNNOption ( const CSphNamedVariant & tOpt, KnnSearchSettings_t
 }
 
 
-bool SqlParser_c::SetKNN ( const SqlNode_t & tAttr, const SqlNode_t & tK, const SqlNode_t & tValues, const CSphVector<CSphNamedVariant> * pOpts )
+bool SqlParser_c::SetKNN ( const SqlNode_t & tAttr, const SqlNode_t & tK, const SqlNode_t & tValues, const CSphVector<CSphNamedVariant> * pOpts, bool bAutoEmb )
 {
 	auto & tKNN = m_pQuery->m_tKnnSettings;
 
@@ -1433,12 +1433,17 @@ bool SqlParser_c::SetKNN ( const SqlNode_t & tAttr, const SqlNode_t & tK, const
 				return false;
 			}
 
-	if ( tValues.m_iValues>=0 )
+	if ( bAutoEmb )
+		ToString ( tKNN.m_sEmbStr, tValues );
+	else
 	{
-		const auto & dValues = GetMvaVec ( tValues.m_iValues );
-		tKNN.m_dVec.Reserve ( dValues.GetLength() );
-		for ( const auto & i : dValues )
-			tKNN.m_dVec.Add( i.m_fValue );
+		if ( tValues.m_iValues>=0 )
+		{
+			const auto & dValues = GetMvaVec ( tValues.m_iValues );
+			tKNN.m_dVec.Reserve ( dValues.GetLength() );
+			for ( const auto & i : dValues )
+				tKNN.m_dVec.Add( i.m_fValue );
+		}
 	}
 	
 	return true;

+ 1 - 0
src/sphinx.h

@@ -517,6 +517,7 @@ struct KnnSearchSettings_t
 	bool			m_bRescore = false;		///< KNN rescoring
 	float			m_fOversampling = 1.0f;	///< KNN oversampling
 	CSphVector<float> m_dVec;				///< KNN anchor vector
+	CSphString		m_sEmbStr;				///< string to generate embeddings from
 };
 
 /// search query. Pure struct, no member functions

+ 1 - 0
src/sphinxexpr.h

@@ -86,6 +86,7 @@ enum ESphExprCommand
 	SPH_EXPR_SET_DOCSTORE_DOCID,	///< interface to fetch docs by docid (postlimit stage)
 	SPH_EXPR_SET_QUERY,
 	SPH_EXPR_SET_EXTRA_DATA,
+	SPH_EXPR_SET_KNN_VEC,
 	SPH_EXPR_GET_DEPENDENT_COLS,	///< used to determine proper evaluating stage
 	SPH_EXPR_GET_GEODIST_SETTINGS,
 	SPH_EXPR_GET_POLY2D_BBOX,

+ 12 - 2
src/sphinxql.y

@@ -739,12 +739,22 @@ on_clause:
 knn_item:
 	TOK_KNN '(' ident ',' const_int ',' '(' const_list ')' ')'
 		{
-			if ( !pParser->SetKNN ( $3, $5, $8, nullptr ) )
+			if ( !pParser->SetKNN ( $3, $5, $8, nullptr, false ) ) // explicit vector literal, no options (bAutoEmb=false)
+				YYERROR;
+		}
+	| TOK_KNN '(' ident ',' const_int ',' TOK_QUOTED_STRING ')'
+		{
+			if ( !pParser->SetKNN ( $3, $5, $7, nullptr, true ) ) // quoted string, no options (bAutoEmb=true): SetKNN stores the text in m_sEmbStr instead of filling m_dVec
 				YYERROR;
 		}
 	| TOK_KNN '(' ident ',' const_int ',' '(' const_list ')' ',' '{' named_const_list '}' ')'
 		{
-			if ( !pParser->SetKNN ( $3, $5, $8, &( pParser->GetNamedVec ( $4.GetValueInt() ) ) ) )
+			if ( !pParser->SetKNN ( $3, $5, $8, &( pParser->GetNamedVec ( $4.GetValueInt() ) ), false ) ) // explicit vector literal plus named options (bAutoEmb=false)
+				YYERROR;
+		}
+	| TOK_KNN '(' ident ',' const_int ',' TOK_QUOTED_STRING ',' '{' named_const_list '}' ')'
+		{
+			if ( !pParser->SetKNN ( $3, $5, $7, &( pParser->GetNamedVec ( $4.GetValueInt() ) ), true ) ) // quoted string plus named options (bAutoEmb=true); NOTE(review): the GetNamedVec index is read from the 4th symbol, which is the first ',' in this rule rather than named_const_list (same as the vector form above) - confirm this is intended
 				YYERROR;
 		}
 	;

+ 90 - 25
src/sphinxrt.cpp

@@ -1598,6 +1598,7 @@ private:
 	int64_t						GetMemLimit() const final { return m_iRtMemLimit; }
 
 	bool						LoadEmbeddingModels ( CSphString & sError );
+	const CSphQuery *			SetupAutoEmbeddings ( const CSphQuery & tQuery, CSphQuery & tUpdatedQuery, const ISphSchema & tMatchSchema, CSphString & sError ) const;
 	bool						VerifyKNN ( InsertDocData_c & tDoc, CSphString & sError ) const;
 
 	template<typename PRED>
@@ -8069,6 +8070,63 @@ ConstRtData FilterReaderChunks ( ConstRtData tOrigin, const VecTraits_T<int64_t>
 	return { pConstChunks, pConstSegments };
 }
 
+
+const CSphQuery * RtIndex_c::SetupAutoEmbeddings ( const CSphQuery & tQuery, CSphQuery & tUpdatedQuery, const ISphSchema & tMatchSchema, CSphString & sError ) const // returns the query to execute: tQuery itself when no embedding text is given, tUpdatedQuery (a copy with a generated KNN vector) otherwise; nullptr with sError set on failure
+{
+	auto & tKNN = tQuery.m_tKnnSettings;
+	if ( tKNN.m_sAttr.IsEmpty() || tKNN.m_sEmbStr.IsEmpty() ) // no text to embed: run the query untouched
+		return &tQuery;
+
+	auto pAttr = m_tSchema.GetAttr ( tKNN.m_sAttr.cstr() );
+	if ( !pAttr )
+	{
+		sError.SetSprintf ( "KNN search attribute '%s' not found", tKNN.m_sAttr.cstr() );
+		return nullptr;
+	}
+
+	knn::TextToEmbeddings_i * pModel = m_pEmbeddings ? m_pEmbeddings->GetModel ( tKNN.m_sAttr ) : nullptr; // a table with no models loaded at all must fail here too; otherwise the text query would run KNN with an empty anchor vector
+	if ( !pModel )
+	{
+		sError.SetSprintf ( "No model loaded for auto embeddings attribute '%s'", tKNN.m_sAttr.cstr() );
+		return nullptr;
+	}
+
+	tUpdatedQuery = tQuery; // work on a copy; the caller's query is const
+
+	std::vector<std::vector<float>> dEmbeddings;
+	std::vector<std::string_view> dTexts;
+	dTexts.push_back( tKNN.m_sEmbStr.cstr() ); // single-text batch: exactly one embedding is expected back
+
+	std::string sConvertError;
+	if ( !pModel->Convert ( dTexts, dEmbeddings, sConvertError ) )
+	{
+		sError.SetSprintf ( "Error generating embeddings for attribute '%s' : %s", tKNN.m_sAttr.cstr(), sConvertError.c_str() );
+		return nullptr;
+	}
+
+	if ( dEmbeddings.size()!=1 ) // guards the dEmbeddings[0] accesses below
+	{
+		sError.SetSprintf ( "Error generating embeddings for attribute '%s'", tKNN.m_sAttr.cstr() );
+		return nullptr;
+	}
+
+	int iEmbDim = static_cast<int> ( dEmbeddings[0].size() ); // explicit narrowing; the value is compared against the attribute's m_iDims
+	if ( iEmbDim!=pAttr->m_tKNN.m_iDims )
+	{
+		sError.SetSprintf ( "Auto generated embedding dimension mismatch: expected %d, got %d", pAttr->m_tKNN.m_iDims, iEmbDim );
+		return nullptr;
+	}
+
+	tUpdatedQuery.m_tKnnSettings.m_dVec.Resize(iEmbDim);
+	memcpy ( tUpdatedQuery.m_tKnnSettings.m_dVec.Begin(), dEmbeddings[0].data(), iEmbDim*sizeof(float) );
+
+	for ( int i = 0; i < tMatchSchema.GetAttrsCount(); i++ ) // push the generated vector into every expression of the match schema; Expr_KNNDist_c handles SPH_EXPR_SET_KNN_VEC by calling SetAnchor()
+		if ( tMatchSchema.GetAttr(i).m_pExpr )
+			tMatchSchema.GetAttr(i).m_pExpr->Command ( SPH_EXPR_SET_KNN_VEC, &tUpdatedQuery.m_tKnnSettings.m_dVec );
+
+	return &tUpdatedQuery;
+}
+
 // FIXME! missing MVA, index_exact_words support
 // FIXME? any chance to factor out common backend agnostic code?
 // FIXME? do we need to support pExtraFilters?
@@ -8159,22 +8217,40 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 
 	SwitchProfile ( pProfiler, SPH_QSTATE_INIT );
 
+	// select the sorter with max schema
+	// uses GetAttrsCount to get working facets (was GetRowSize)
+	int iMaxSchemaIndex, iMatchPoolSize;
+	std::tie ( iMaxSchemaIndex, iMatchPoolSize ) = GetMaxSchemaIndexAndMatchCapacity ( dSorters );
+
+	if ( iMaxSchemaIndex==-1 )
+		return false;
+
+	const ISphSchema & tMaxSorterSchema = *( dSorters[iMaxSchemaIndex]->GetSchema ());
+	auto dSorterSchemas = SorterSchemas ( dSorters, iMaxSchemaIndex );
+
+	CSphQuery tUpdatedQuery;
+	const CSphQuery * pQueryToRun = SetupAutoEmbeddings ( tQuery, tUpdatedQuery, tMaxSorterSchema, tMeta.m_sError );
+	if ( !pQueryToRun )
+		return false;
+
+	auto & tQueryToRun = *pQueryToRun;
+
 	// FIXME! each result will point to its own MVA and string pools
 
 	//////////////////////
 	// search disk chunks
 	//////////////////////
 
-	tMeta.m_bHasPrediction = tQuery.m_iMaxPredictedMsec>0;
+	tMeta.m_bHasPrediction = tQueryToRun.m_iMaxPredictedMsec>0;
 
 	MiniTimer_c dTimerGuard;
-	int64_t tmMaxTimer = dTimerGuard.Engage ( tQuery.m_uMaxQueryMsec ); // max_query_time
+	int64_t tmMaxTimer = dTimerGuard.Engage ( tQueryToRun.m_uMaxQueryMsec ); // max_query_time
 
 	SorterSchemaTransform_c tSSTransform ( dDiskChunks.GetLength(), tArgs.m_bFinalizeSorters );
 
 	if ( !dDiskChunks.IsEmpty() )
 	{
-		if ( !QueryDiskChunks ( tQuery, tMeta, tArgs, tGuard, dSorters, pProfiler, bGotLocalDF, pLocalDocs, iTotalDocs, GetName(), tSSTransform, tmMaxTimer ) )
+		if ( !QueryDiskChunks ( tQueryToRun, tMeta, tArgs, tGuard, dSorters, pProfiler, bGotLocalDF, pLocalDocs, iTotalDocs, GetName(), tSSTransform, tmMaxTimer ) )
 			return false;
 	}
 
@@ -8184,19 +8260,8 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 
 	SwitchProfile ( pProfiler, SPH_QSTATE_INIT );
 
-	// select the sorter with max schema
-	// uses GetAttrsCount to get working facets (was GetRowSize)
-	int iMaxSchemaIndex, iMatchPoolSize;
-	std::tie ( iMaxSchemaIndex, iMatchPoolSize ) = GetMaxSchemaIndexAndMatchCapacity ( dSorters );
-
-	if ( iMaxSchemaIndex==-1 )
-		return false;
-
-	const ISphSchema & tMaxSorterSchema = *( dSorters[iMaxSchemaIndex]->GetSchema ());
-	auto dSorterSchemas = SorterSchemas ( dSorters, iMaxSchemaIndex );
-
 	// setup calculations and result schema
-	CSphQueryContext tCtx ( tQuery );
+	CSphQueryContext tCtx ( tQueryToRun );
 	tCtx.m_pProfile = pProfiler;
 	tCtx.m_pLocalDocs = pLocalDocs;
 	tCtx.m_iTotalDocs = iTotalDocs;
@@ -8207,7 +8272,7 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 	tTermSetup.SetDict ( pDict );
 	tTermSetup.m_pIndex = this;
 	tTermSetup.m_iDynamicRowitems = tMaxSorterSchema.GetDynamicSize();
-	tTermSetup.m_iMaxTimer = dTimerGuard.Engage ( tQuery.m_uMaxQueryMsec ); // max_query_time
+	tTermSetup.m_iMaxTimer = dTimerGuard.Engage ( tQueryToRun.m_uMaxQueryMsec ); // max_query_time
 	tTermSetup.m_pWarning = &tMeta.m_sWarning;
 	tTermSetup.SetSegment ( -1 );
 	tTermSetup.m_pCtx = &tCtx;
@@ -8215,17 +8280,17 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 
 	// setup prediction constrain
 	CSphQueryStats tQueryStats;
-	int64_t iNanoBudget = (int64_t)(tQuery.m_iMaxPredictedMsec) * 1000000; // from milliseconds to nanoseconds
+	int64_t iNanoBudget = (int64_t)(tQueryToRun.m_iMaxPredictedMsec) * 1000000; // from milliseconds to nanoseconds
 	tQueryStats.m_pNanoBudget = &iNanoBudget;
 	if ( tMeta.m_bHasPrediction )
 		tTermSetup.m_pStats = &tQueryStats;
 
 	// bind weights
-	tCtx.BindWeights ( tQuery, m_tSchema, tMeta.m_sWarning );
+	tCtx.BindWeights ( tQueryToRun, m_tSchema, tMeta.m_sWarning );
 
 	CSphVector<BYTE> dFiltered;
-	const BYTE * sModifiedQuery = (const BYTE *)tQuery.m_sQuery.cstr();
-	FieldFilterOptions_t tFFOptions { tQuery.m_eJiebaMode };
+	const BYTE * sModifiedQuery = (const BYTE *)tQueryToRun.m_sQuery.cstr();
+	FieldFilterOptions_t tFFOptions { tQueryToRun.m_eJiebaMode };
 
 	if ( m_pFieldFilter && sModifiedQuery && m_pFieldFilter->Clone ( &tFFOptions )->Apply ( sModifiedQuery, dFiltered, true ) )
 		sModifiedQuery = dFiltered.Begin();
@@ -8248,13 +8313,13 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 	if ( !bFullscan )
 	{
 		assert ( m_pQueryTokenizer.Ptr() && m_pQueryTokenizerJson.Ptr() );
-		if ( !pQueryParser->ParseQuery ( tParsed, (const char *)sModifiedQuery, &tQuery, m_pQueryTokenizer, m_pQueryTokenizerJson, &m_tSchema, pDict, m_tSettings, &m_tMorphFields ) )
+		if ( !pQueryParser->ParseQuery ( tParsed, (const char *)sModifiedQuery, &tQueryToRun, m_pQueryTokenizer, m_pQueryTokenizerJson, &m_tSchema, pDict, m_tSettings, &m_tMorphFields ) )
 		{
 			tMeta.m_sError = tParsed.m_sParseError;
 			iStackNeed = 0;
 		} else
 		{
-			iStackNeed = PrepareFTSearch ( this, IsStarDict ( m_bKeywordDict ), m_bKeywordDict, m_tMutableSettings.m_iExpandKeywords, m_iExpansionLimit, m_tSettings, tQuery,(cRefCountedRefPtrGeneric_t) tGuard.m_tSegmentsAndChunks.m_pSegs, pDict, tMeta, pProfiler, &tPayloads, tParsed );
+			iStackNeed = PrepareFTSearch ( this, IsStarDict ( m_bKeywordDict ), m_bKeywordDict, m_tMutableSettings.m_iExpandKeywords, m_iExpansionLimit, m_tSettings, tQueryToRun,(cRefCountedRefPtrGeneric_t) tGuard.m_tSegmentsAndChunks.m_pSegs, pDict, tMeta, pProfiler, &tPayloads, tParsed );
 		}
 	}
 
@@ -8280,19 +8345,19 @@ bool RtIndex_c::MultiQuery ( CSphQueryResult & tResult, const CSphQuery & tQuery
 
 	CSphVector<CSphFilterSettings> dTransformedFilters; // holds filter settings if they were modified. filters hold pointers to those settings
 	CSphVector<FilterTreeItem_t> dTransformedFilterTree;
-	if ( !SetupFilters ( tQuery, tMaxSorterSchema, m_tSchema, bParsedFullscan, tCtx, dTransformedFilters, dTransformedFilterTree, dSorterSchemas, tMeta ) )
+	if ( !SetupFilters ( tQueryToRun, tMaxSorterSchema, m_tSchema, bParsedFullscan, tCtx, dTransformedFilters, dTransformedFilterTree, dSorterSchemas, tMeta ) )
 		return false;
 
 	bool bResult;
 	if ( bParsedFullscan )
-		bResult = DoFullScanQuery ( tGuard.m_dRamSegs, tMaxSorterSchema, tQuery, tArgs, m_iStride, tmMaxTimer, pProfiler, tCtx, dSorters, tMeta );
+		bResult = DoFullScanQuery ( tGuard.m_dRamSegs, tMaxSorterSchema, tQueryToRun, tArgs, m_iStride, tmMaxTimer, pProfiler, tCtx, dSorters, tMeta );
 	else
 	{
 		CSphMultiQueryArgs tFTArgs ( tArgs.m_iIndexWeight );
 		tFTArgs.m_bFinalizeSorters = tArgs.m_bFinalizeSorters;
 		tMeta.m_bBigram = ( m_tSettings.m_eBigramIndex!=SPH_BIGRAM_NONE );
 
-		bResult = DoFullTextSearch ( tGuard.m_dRamSegs, tMaxSorterSchema, tQuery, tFTArgs, iMatchPoolSize, iStackNeed, tTermSetup, pProfiler, tCtx, dSorters, tParsed, tMeta, dSorters.GetLength()==1 ? dSorters[0] : nullptr );
+		bResult = DoFullTextSearch ( tGuard.m_dRamSegs, tMaxSorterSchema, tQueryToRun, tFTArgs, iMatchPoolSize, iStackNeed, tTermSetup, pProfiler, tCtx, dSorters, tParsed, tMeta, dSorters.GetLength()==1 ? dSorters[0] : nullptr );
 	}
 
 	if (!bResult)