zirobtc commited on
Commit
db0a14e
·
1 Parent(s): 9dd732c

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,776 +1,42 @@
1
- # =========================================
2
- # Entity Encoders
3
- # =========================================
4
- # These are generated offline/streaming and are the "vocabulary" for the model.
5
 
6
- <WalletEmbedding> # Embedding of a wallet's relationships, behavior, and history.
7
- <WalletEmbedding> = [
8
- // Data from the 'wallet_profiles' table (Wallet-level lifetime and daily/weekly stats)
9
- wallet_profiles_row: [
10
- // Core Info & Timestamps
11
- age, // No Contextual
12
- wallet_address, // Primary wallet identifier
13
-
14
 
15
- // 7. NEW: Deployed Token Aggregates (8 Features)
16
- deployed_tokens_count, // Total tokens created
17
- deployed_tokens_migrated_pct, // % that migrated
18
- deployed_tokens_avg_lifetime_sec, // Avg duration before dev selling
19
- deployed_tokens_avg_peak_mc_usd, // Avg peak marketcap
20
- deployed_tokens_median_peak_mc_usd,
21
-
22
- // Metadata & Balances
23
- balance, // Current SOL balance
24
-
25
- // Lifetime Transaction Counts (Total history)
26
- transfers_in_count, // Total native transfers received
27
- transfers_out_count, // Total native transfers sent
28
- spl_transfers_in_count, // Total SPL token transfers received
29
- spl_transfers_out_count,// Total SPL token transfers sent
30
-
31
- // Lifetime Trading Stats (Total history)
32
- total_buys_count, // Total buys across all tokens
33
- total_sells_count, // Total sells across all tokens
34
- total_winrate, // Overall trading winrate
35
-
36
- // 1-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
37
- stats_1d_realized_profit_sol,
38
- stats_1d_realized_profit_pnl,
39
- stats_1d_buy_count,
40
- stats_1d_sell_count,
41
- stats_1d_transfer_in_count,
42
- stats_1d_transfer_out_count,
43
- stats_1d_avg_holding_period,
44
- stats_1d_total_bought_cost_sol,
45
- stats_1d_total_sold_income_sol,
46
- stats_1d_total_fee,
47
- stats_1d_winrate,
48
- stats_1d_tokens_traded,
49
-
50
- // 7-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
51
- stats_7d_realized_profit_sol,
52
- stats_7d_realized_profit_pnl,
53
- stats_7d_buy_count,
54
- stats_7d_sell_count,
55
- stats_7d_transfer_in_count,
56
- stats_7d_transfer_out_count,
57
- stats_7d_avg_holding_period,
58
- stats_7d_total_bought_cost_sol,
59
- stats_7d_total_sold_income_sol,
60
- stats_7d_total_fee,
61
- stats_7d_winrate,
62
- stats_7d_tokens_traded,
63
 
64
- // 30-day stats are not useful in this context
65
- ],
66
 
67
- // Data from the 'wallet_socials' table (Social media and profile info)
68
- wallet_socials_row: [
69
- has_pf_profile,
70
- has_twitter,
71
- has_telegram,
72
- is_exchange_wallet,
73
- username,
74
- ],
75
- // Data from the 'wallet_holdings' table (Token-level statistics for held tokens)
76
- wallet_holdings_pool: [
77
- <TokenVibeEmbedding>,
78
- holding_time, // How long the wallet held the token (we only check tokens it currently holds or recently traded)
79
-
80
- balance_pct_to_supply, // Current quantity of the token held
81
-
82
- // History (Amounts & Costs)
83
- history_bought_amount_sol, // Total amount of token bought
84
- bought_amount_sol_pct_to_native_balance // Whether the wallet traded a large share of its balance
85
-
86
- // History (Counts)
87
- history_total_buys, // Total number of buy transactions
88
- history_total_sells, // Total number of sell transactions
89
-
90
- // Profit and Loss
91
- realized_profit_pnl, // Realized P&L as a percentage
92
- realized_profit_sol,
93
-
94
- // Transfers (Non-trade movements)
95
- history_transfer_in,
96
- history_transfer_out,
97
-
98
- avarage_trade_gap_seconds,
99
- total_priority_fees, // Total tips + Priority Fees
100
- ]
101
- ]
102
 
103
- <TokenVibeEmbedding> # Multimodal embedding of a token's identity
104
- <TokenVibeEmbedding> = [<TokenAddressEmbedding>, <NameEmbedding>, <SymbolEmbedding>, <ImageEmbedding>, protocol_id]
105
 
106
- <TextEmbedding> # Text embedding MultiModal processor.
107
- <MediaEmbedding> # Multimodal VIT encoder.
108
-
109
- # -----------------------------------------
110
- # 1. TradeEncoder
111
- # -----------------------------------------
112
-
113
- # Captures large-size trades from any wallet.
114
- [timestamp, 'LargeTrade', relative_ts, <WalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
115
-
116
- # Captures the high-signal "Dev Sold or Bought" event.
117
- [timestamp, 'Deployer_Trade', relative_ts, <CreatorWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
118
-
119
- # Captures *all* trades from pre-defined high-P&L/win-rate, KOL, and known wallets.
120
- [timestamp, 'SmartWallet_Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
121
-
122
- # Raw trades. Loaded in H/B/H Prefix (first ~10k) and Suffix (last ~5k).
123
- [timestamp, 'Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
124
-
125
- # -----------------------------------------
126
- # 2. TransferEncoder
127
- # -----------------------------------------
128
-
129
- # Raw transfers. Loaded in H/B/H Prefix (all in first ~10k trade window) and Suffix (all in last ~5k trade window).
130
- [timestamp, 'Transfer', relative_ts, <SourceWalletEmbedding>, <DestinationWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
131
-
132
- # Captures scarce, large transfers *after* the initial launch window.
133
- [timestamp, 'LargeTransfer', relative_ts, <FromWalletEmbedding>, <ToWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
134
-
135
- # -----------------------------------------
136
- # 3. LifecycleEncoder
137
- # -----------------------------------------
138
-
139
- # The T0 event.
140
- [timestamp, 'Mint', 0, <CreatorWalletEmbedding>, <TokenVibeEmbedding>]
141
-
142
- # -----------------------------------------
143
- # 4. PoolEncoder
144
- # -----------------------------------------
145
-
146
- # Signals migration from launchpad to a real pool.
147
- [timestamp, 'PoolCreated', relative_ts, <ProviderWalletEmbedding>, protocol_id, <QuoteTokenVibeEmbedding>, base_amount, quote_amount, quote_pct_to_main_pool_balance, base_pct_to_main_pool_balance]
148
-
149
- # Signals LP addition or removal.
150
- [timestamp, 'LiquidityChange', relative_ts, <ProviderWalletEmbedding>, <QuoteTokenVibeEmbedding>, change_type_id, quote_amount, quote_pct_to_current_pool_balance]
151
-
152
- # Signals creator/dev taking platform fees.
153
- [timestamp, 'FeeCollected', relative_ts, <RecipientWalletEmbedding>, sol_amount, token_amount]
154
-
155
-
156
- # -----------------------------------------
157
- # SupplyEncoder
158
- # -----------------------------------------
159
-
160
- # Signals a supply reduction.
161
- [timestamp, 'TokenBurn', relative_ts, <BurnerWalletEmbedding>, amount_pct_of_total_supply, amount_tokens_burned]
162
-
163
- # Signals locked supply, e.g., for team/marketing.
164
- [timestamp, 'SupplyLock', relative_ts, <LockerWalletEmbedding>, amount_pct_of_total_supply, lock_duration]
165
-
166
- # -----------------------------------------
167
- # ChartEncoder
168
- # -----------------------------------------
169
-
170
- # (The "Sliding Window") This is the new chart event.
171
- [timestamp, 'Chart_Segment', relative_ts, OHLC_segment, chart_interval_id]
172
-
173
- # -----------------------------------------
174
- # PulseEncoder
175
- # -----------------------------------------
176
-
177
- # It is a low-frequency event (Dynamic Interval: 5min, 15min, or 1hr based on token age).
178
- [timestamp, 'OnChain_Snapshot', relative_ts, total_holders, smart_traders, kols, holder_growth_rate, top_10_holder_pct, sniper_holding_pct, rat_wallets_holding_pct, bundle_holding_pct, current_market_cap, liquidity, volume, buy_count, sell_count, total_txns, global_fees_paid]
179
-
180
- # -----------------------------------------
181
- # HoldersListEncoder
182
- # -----------------------------------------
183
-
184
- <HolderDistributionEmbedding> # Transformer-based embedding of the top holders (WalletEmbeddings + Pct).
185
-
186
- # Token-specific holder analysis.
187
- [timestamp, 'HolderSnapshot', relative_ts, <HolderDistributionEmbedding>]
188
-
189
-
190
- # -----------------------------------------
191
- # ChainSnapshotEncoder
192
- # -----------------------------------------
193
-
194
- # Broad chain-level market conditions.
195
- [timestamp, 'ChainSnapshot', relative_ts, native_token_price_usd, gas_fee]
196
-
197
- # Launchpad market regime (using absolute, log-normalized values).
198
- [timestamp, 'Lighthouse_Snapshot', relative_ts, protocol_id, timeframe_id, total_volume, total_transactions, total_traders, total_tokens_created, total_migrations]
199
-
200
- # -----------------------------------------
201
- # TokenTrendingListEncoder
202
- # -----------------------------------------
203
-
204
- # Fires *per token* on a trending list. The high-attention "meta" signal.
205
- [timestamp, 'TrendingToken', relative_ts, <TokenVibeEmbedding_of_trending_token>, list_source_id, timeframe_id, rank]
206
-
207
- # Fires *per token* on the boosted list.
208
- [timestamp, 'BoostedToken', relative_ts, <TokenVibeEmbedding_of_boosted_token>, total_boost_amount, rank]
209
-
210
- # -----------------------------------------
211
- # LaunchpadThreadEncoder
212
- # -----------------------------------------
213
-
214
- # On-platform social signal (Pump.fun comments).
215
- [timestamp, 'PumpReply', relative_ts, <UserWalletEmbedding>, <ReplyTextEmbedding>]
216
-
217
- # -----------------------------------------
218
- # CTEncoder
219
- # -----------------------------------------
220
-
221
- # Off-platform social signal (Twitter).
222
- [timestamp, 'XPost', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>]
223
- [timestamp, 'XRetweet', relative_ts, <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
224
- [timestamp, 'XReply', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding>]
225
- [timestamp, 'XQuoteTweet', relative_ts, <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
226
-
227
- # -----------------------------------------
228
- # GlobalTrendingEncoder
229
- # -----------------------------------------
230
-
231
- # Broader cultural trend signal (TikTok).
232
- [timestamp, 'TikTok_Trending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
233
-
234
- # Broader cultural trend signal (Twitter).
235
- [timestamp, 'XTrending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
236
-
237
- # -----------------------------------------
238
- # TrackerEncoder
239
- # -----------------------------------------
240
-
241
- # Retail marketing signal (Paid groups).
242
- [timestamp, 'AlphaGroup_Call', relative_ts, group_id]
243
-
244
- [timestamp, 'Call_Channel', relative_ts, channel_id]
245
-
246
- # High-impact catalyst event.
247
- [timestamp, 'CexListing', relative_ts, exchange_id]
248
-
249
- # High-impact catalyst event.
250
- [timestamp, 'Migrated', relative_ts, protocol_id]
251
-
252
- # -----------------------------------------
253
- # Dex Encoder
254
- # -----------------------------------------
255
-
256
- [timestamp, 'DexBoost_Paid', relative_ts, amount, total_amount_on_token]
257
-
258
- [timestamp, 'DexProfile_Updated', relative_ts, has_changed_website_flag, has_changed_twitter_flag, has_changed_telegram_flag, has_changed_description_flag, <WebsiteEmbedding>, <TwitterLinkEmbedding>, <NewDescriptionEmbedding>]
259
-
260
- ### **Global Context Injection**
261
-
262
- <PRELAUNCH> <LAUNCH> <Middle> <RECENT>
263
-
264
- ### **Token Role Embedding**
265
-
266
- <TokenVibeEmbedding_of_Token_A> + Subject_Token_Role
267
-
268
- <TokenVibeEmbedding_of_Token_B> + Trending_Token_Role
269
-
270
- <QuoteTokenVibeEmbedding_of_USDC> + Quote_Token_Role
271
-
272
-
273
- # **Links**
274
-
275
- ### `TransferLink`
276
-
277
- ```
278
- ['signature', 'source', 'destination', 'mint', 'timestamp']
279
  ```
280
 
281
- -----
282
-
283
- ### `BundleTradeLink`
284
-
285
- ```
286
- ['signatures', 'wallet_a', 'wallet_b', 'mint', 'slot', 'timestamp']
287
- ```
288
-
289
- -----
290
-
291
- ### `CopiedTradeLink`
292
-
293
- ```
294
- ['leader_buy_sig', 'leader_sell_sig', 'follower_buy_sig', 'follower_sell_sig', 'follower', 'leader', 'mint', 'time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'leader_buy_total', 'leader_sell_total', 'follower_buy_total', 'follower_sell_total', 'follower_buy_slippage', 'follower_sell_slippage']
295
- ```
296
-
297
- -----
298
-
299
- ### `CoordinatedActivityLink`
300
-
301
- ```
302
- ['leader_first_sig', 'leader_second_sig', 'follower_first_sig', 'follower_second_sig', 'follower', 'leader', 'mint', 'time_gap_on_first_sec', 'time_gap_on_second_sec']
303
- ```
304
-
305
- -----
306
-
307
- ### `MintedLink`
308
-
309
- ```
310
- ['signature', 'timestamp', 'buy_amount']
311
- ```
312
-
313
- -----
314
-
315
- ### `SnipedLink`
316
-
317
- ```
318
- ['signature', 'rank', 'sniped_amount']
319
- ```
320
-
321
- -----
322
-
323
- ### `LockedSupplyLink`
324
-
325
- ```
326
- ['signature', 'amount', 'unlock_timestamp']
327
- ```
328
-
329
- -----
330
-
331
- ### `BurnedLink`
332
-
333
- ```
334
- ['signature', 'amount', 'timestamp']
335
- ```
336
-
337
- -----
338
-
339
- ### `ProvidedLiquidityLink`
340
-
341
  ```
342
- ['signature', 'wallet', 'token', 'pool_address', 'amount_base', 'amount_quote', 'timestamp']
343
- ```
344
-
345
- -----
346
-
347
- ### `WhaleOfLink`
348
-
349
- ```
350
- ['wallet', 'token', 'holding_pct_at_creation', 'ath_usd_at_creation']
351
- ```
352
-
353
- -----
354
-
355
- ### `TopTraderOfLink`
356
-
357
- ```
358
- ['wallet', 'token', 'pnl_at_creation', 'ath_usd_at_creation']
359
- ```
360
-
361
-
362
-
363
-
364
- /////
365
-
366
- def __gettestitem__(self, idx: int) -> Dict[str, Any]:
367
- """
368
- Generates a single complex data item, structured for the MemecoinCollator.
369
- NOTE: This currently returns the same mock data regardless of `idx`.
370
- """
371
- # --- 1. Setup Pooler and Define Raw Data ---
372
- pooler = EmbeddingPooler()
373
-
374
- # --- 5. Create Mock Raw Batch Data (FIXED) ---
375
- print("Creating mock raw batch...")
376
-
377
- # (Wallet profiles, socials, holdings definitions are unchanged)
378
- profile1 = {
379
- 'wallet_address': 'addrW1', 'age': 1.5e7, 'balance': 10.5,
380
- 'deployed_tokens_count': 2, 'deployed_tokens_migrated_pct': 0.5, 'deployed_tokens_avg_lifetime_sec': 36000.0, 'deployed_tokens_avg_peak_mc_usd': 100000.0, 'deployed_tokens_median_peak_mc_usd': 50000.0,
381
- 'transfers_in_count': 10, 'transfers_out_count': 5, 'spl_transfers_in_count': 20, 'spl_transfers_out_count': 15,
382
- 'total_buys_count': 50, 'total_sells_count': 40, 'total_winrate': 0.6,
383
- 'stats_1d_realized_profit_sol': 1.2, 'stats_1d_realized_profit_pnl': 0.1, 'stats_1d_buy_count': 5, 'stats_1d_sell_count': 3, 'stats_1d_transfer_in_count': 2, 'stats_1d_transfer_out_count': 1, 'stats_1d_avg_holding_period': 3600, 'stats_1d_total_bought_cost_sol': 10.0, 'stats_1d_total_sold_income_sol': 11.2, 'stats_1d_total_fee': 0.1, 'stats_1d_winrate': 0.7, 'stats_1d_tokens_traded': 4,
384
- 'stats_7d_realized_profit_sol': 5.0, 'stats_7d_realized_profit_pnl': 0.2, 'stats_7d_buy_count': 20, 'stats_7d_sell_count': 15, 'stats_7d_transfer_in_count': 8, 'stats_7d_transfer_out_count': 4, 'stats_7d_avg_holding_period': 7200, 'stats_7d_total_bought_cost_sol': 40.0, 'stats_7d_total_sold_income_sol': 45.0, 'stats_7d_total_fee': 0.5, 'stats_7d_winrate': 0.65, 'stats_7d_tokens_traded': 10,
385
- }
386
- social1 = {'has_pf_profile': True, 'has_twitter': True, 'has_telegram': False, 'is_exchange_wallet': False, 'username': 'trader_one'}
387
- holdings1 = [
388
- {'mint_address': 'tknA', 'holding_time': 3600.0, 'realized_profit_sol': 5.2, 'total_priority_fees': 0.05, 'balance_pct_to_supply': 0.01, 'history_bought_amount_sol': 10, 'bought_amount_sol_pct_to_native_balance': 0.5, 'history_total_buys': 5, 'history_total_sells': 2, 'realized_profit_pnl': 0.52, 'history_transfer_in': 1, 'history_transfer_out': 0, 'avarage_trade_gap_seconds': 300},
389
- ]
390
- profile2 = {
391
- 'wallet_address': 'addrW2', 'age': 1e6, 'balance': 1.0,
392
- 'deployed_tokens_count': 0, 'deployed_tokens_migrated_pct': 0.0, 'deployed_tokens_avg_lifetime_sec': 0.0, 'deployed_tokens_avg_peak_mc_usd': 0.0, 'deployed_tokens_median_peak_mc_usd': 0.0,
393
- 'transfers_in_count': 1, 'transfers_out_count': 0, 'spl_transfers_in_count': 0, 'spl_transfers_out_count': 0,
394
- 'total_buys_count': 0, 'total_sells_count': 0, 'total_winrate': 0.0,
395
- 'stats_1d_realized_profit_sol': 0.0, 'stats_1d_realized_profit_pnl': 0.0, 'stats_1d_buy_count': 0, 'stats_1d_sell_count': 0, 'stats_1d_transfer_in_count': 0, 'stats_1d_transfer_out_count': 0, 'stats_1d_avg_holding_period': 0, 'stats_1d_total_bought_cost_sol': 0.0, 'stats_1d_total_sold_income_sol': 0.0, 'stats_1d_total_fee': 0.0, 'stats_1d_winrate': 0.0, 'stats_1d_tokens_traded': 0,
396
- 'stats_7d_realized_profit_sol': 0.0, 'stats_7d_realized_profit_pnl': 0.0, 'stats_7d_buy_count': 0, 'stats_7d_sell_count': 0, 'stats_7d_transfer_in_count': 0, 'stats_7d_transfer_out_count': 0, 'stats_7d_avg_holding_period': 0, 'stats_7d_total_bought_cost_sol': 0.0, 'stats_7d_total_sold_income_sol': 0.0, 'stats_7d_total_fee': 0.0, 'stats_7d_winrate': 0.0, 'stats_7d_tokens_traded': 0,
397
- }
398
- social2 = {'has_pf_profile': False, 'has_twitter': False, 'has_telegram': False, 'is_exchange_wallet': True, 'username': 'cex_wallet'}
399
- holdings2 = []
400
-
401
-
402
- # Define raw data and get their indices
403
- tokenA_data = {
404
- 'address_emb_idx': pooler.get_idx('tknA'),
405
- 'name_emb_idx': pooler.get_idx('Token A'),
406
- 'symbol_emb_idx': pooler.get_idx('TKA'),
407
- 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
408
- 'protocol': 1
409
- }
410
- # Add wallet usernames to the pool
411
- wallet1_user_idx = pooler.get_idx(social1['username'])
412
- wallet2_user_idx = pooler.get_idx(social2['username'])
413
- social1['username_emb_idx'] = wallet1_user_idx
414
- social2['username_emb_idx'] = wallet2_user_idx
415
- # --- NEW: Add a third wallet for social tests ---
416
- social3 = {'has_pf_profile': False, 'has_twitter': True, 'has_telegram': True, 'is_exchange_wallet': False, 'username': 'social_butterfly'}
417
- wallet3_user_idx = pooler.get_idx(social3['username'])
418
- social3['username_emb_idx'] = wallet3_user_idx
419
-
420
- # Create the final pre-computed data structures
421
- tokenB_data = {
422
- 'address_emb_idx': pooler.get_idx('tknA'),
423
- 'name_emb_idx': pooler.get_idx('Token A'),
424
- 'symbol_emb_idx': pooler.get_idx('TKA'),
425
- 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
426
- 'protocol': 1
427
- }
428
-
429
- tokenC_data = {
430
- 'address_emb_idx': pooler.get_idx('tknA'),
431
- 'name_emb_idx': pooler.get_idx('Token A'),
432
- 'symbol_emb_idx': pooler.get_idx('TKA'),
433
- 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
434
- 'protocol': 1
435
- }
436
-
437
- tokenD_data = {
438
- 'address_emb_idx': pooler.get_idx('tknA'),
439
- 'name_emb_idx': pooler.get_idx('Token A'),
440
- 'symbol_emb_idx': pooler.get_idx('TKA'),
441
- 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
442
- 'protocol': 1
443
- }
444
-
445
- item = {
446
- 'event_sequence': [
447
- {'event_type': 'XPost', # NEW
448
- 'timestamp': 1729711350,
449
- 'relative_ts': -25,
450
- 'wallet_address': 'addrW1', # Author
451
- 'text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
452
- 'media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
453
- },
454
- {'event_type': 'XReply', # NEW
455
- 'timestamp': 1729711360,
456
- 'relative_ts': -35,
457
- 'wallet_address': 'addrW2', # Replier
458
- 'text_emb_idx': pooler.get_idx('This is a reply to the main tweet'),
459
- 'media_emb_idx': pooler.get_idx(None), # No media in reply
460
- 'main_tweet_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA')
461
- },
462
- {'event_type': 'XRetweet', # NEW
463
- 'timestamp': 1729711370,
464
- 'relative_ts': -40,
465
- 'wallet_address': 'addrW3', # The retweeter
466
- 'original_author_wallet_address': 'addrW1', # The original author
467
- 'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
468
- 'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
469
- },
470
- # --- CORRECTED: Test a pre-launch event with negative relative_ts ---
471
- {'event_type': 'Transfer',
472
- 'timestamp': 1729711180,
473
- 'relative_ts': -10, # Negative relative_ts indicates pre-launch
474
- 'wallet_address': 'addrW2',
475
- 'destination_wallet_address': 'addrW1',
476
- 'token_address': 'tknA',
477
- 'token_amount': 1000.0, 'transfer_pct_of_total_supply': 0.0, 'transfer_pct_of_holding': 0.0, 'priority_fee': 0.0
478
- },
479
- {'event_type': 'Mint', 'timestamp': 1729711190, 'relative_ts': 0, 'wallet_address': 'addrW1', 'token_address': 'tknA'},
480
- {'event_type': 'Chart_Segment', 'timestamp': 1729711200, 'relative_ts': 60, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'}, # This is high-def (segment 0) by default
481
- {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 120, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'}, # You can mark this as blurry
482
- {'event_type': 'Transfer',
483
- 'timestamp': 1729711210,
484
- 'relative_ts': 20,
485
- 'wallet_address': 'addrW1', # Source
486
- 'destination_wallet_address': 'addrW2', # Destination
487
- 'token_address': 'tknA', # Need token for context? (Optional, depends on design)
488
- 'token_amount': 500.0,
489
- 'transfer_pct_of_total_supply': 0.005,
490
- 'transfer_pct_of_holding': 0.1,
491
- 'priority_fee': 0.0001
492
- },
493
- {'event_type': 'Trade',
494
- 'timestamp': 1729711220,
495
- 'relative_ts': 30,
496
- 'wallet_address': 'addrW1',
497
- 'token_address': 'tknA',
498
- 'trade_direction': 0,
499
- 'sol_amount': 0.5,
500
- # --- FIXED: Pass the integer ID directly ---
501
- 'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
502
- 'priority_fee': 0.0002,
503
- 'mev_protection': False,
504
- 'token_amount_pct_of_holding': 0.05, 'quote_amount_pct_of_holding': 0.02,
505
- 'slippage': 0.01, 'price_impact': 0.005, 'success': True, 'is_bundle': False, 'total_usd': 75.0
506
- },
507
- {'event_type': 'Deployer_Trade', # NEW: Testing a trade variant
508
- 'timestamp': 1729711230,
509
- 'relative_ts': 40,
510
- 'wallet_address': 'addrW1', # The creator wallet
511
- 'token_address': 'tknA',
512
- 'trade_direction': 1, 'sol_amount': 0.2,
513
- # --- FIXED: Pass the integer ID directly ---
514
- 'dex_platform_id': vocab.DEX_TO_ID['Trojan'],
515
- 'priority_fee': 0.0005,
516
- 'mev_protection': True,
517
- 'token_amount_pct_of_holding': 0.1, 'quote_amount_pct_of_holding': 0.0,
518
- 'slippage': 0.02, 'price_impact': 0.01, 'success': True, 'is_bundle': False, 'total_usd': 30.0
519
- },
520
- {'event_type': 'SmartWallet_Trade', # NEW
521
- 'timestamp': 1729711240,
522
- 'relative_ts': 50,
523
- 'wallet_address': 'addrW1', # A known smart wallet
524
- 'token_address': 'tknA',
525
- 'trade_direction': 0, 'sol_amount': 1.5,
526
- # --- FIXED: Pass the integer ID directly ---
527
- 'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
528
- 'priority_fee': 0.001,
529
- 'mev_protection': True,
530
- 'token_amount_pct_of_holding': 0.2, 'quote_amount_pct_of_holding': 0.1,
531
- 'slippage': 0.01, 'price_impact': 0.008, 'success': True, 'is_bundle': False, 'total_usd': 225.0
532
- },
533
- {'event_type': 'LargeTrade', # NEW
534
- 'timestamp': 1729711250,
535
- 'relative_ts': 60,
536
- 'wallet_address': 'addrW2', # Some other wallet
537
- 'token_address': 'tknA',
538
- 'trade_direction': 0, 'sol_amount': 10.0,
539
- # --- FIXED: Pass the integer ID directly ---
540
- 'dex_platform_id': vocab.DEX_TO_ID['OXK'],
541
- 'priority_fee': 0.002,
542
- 'mev_protection': False,
543
- 'token_amount_pct_of_holding': 0.8, 'quote_amount_pct_of_holding': 0.5,
544
- 'slippage': 0.03, 'price_impact': 0.05, 'success': True, 'is_bundle': False, 'total_usd': 1500.0
545
- },
546
- {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 70, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},
547
- {'event_type': 'PoolCreated', # NEW
548
- 'timestamp': 1729711270,
549
- 'relative_ts': 80,
550
- 'wallet_address': 'addrW1',
551
- 'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM'],
552
- 'quote_token_address': 'tknB',
553
- 'base_amount': 1000000.0,
554
- 'quote_amount': 10.0
555
- },
556
- {'event_type': 'LiquidityChange', # NEW
557
- 'timestamp': 1729711280,
558
- 'relative_ts': 90,
559
- 'wallet_address': 'addrW2',
560
- 'quote_token_address': 'tknB',
561
- 'change_type_id': 0, # 0 for 'add'
562
- 'quote_amount': 2.0
563
- },
564
- {'event_type': 'FeeCollected', # NEW
565
- 'timestamp': 1729711290,
566
- 'relative_ts': 100,
567
- 'wallet_address': 'addrW1', # The recipient (e.g., dev wallet)
568
- 'sol_amount': 0.1
569
- },
570
- {'event_type': 'TokenBurn', # NEW
571
- 'timestamp': 1729711300,
572
- 'relative_ts': 110,
573
- 'wallet_address': 'addrW2', # The burner wallet
574
- 'amount_pct_of_total_supply': 0.01, # 1% of supply
575
- 'amount_tokens_burned': 10000000.0
576
- },
577
- {'event_type': 'SupplyLock', # NEW
578
- 'timestamp': 1729711310,
579
- 'relative_ts': 120,
580
- 'wallet_address': 'addrW1', # The locker wallet
581
- 'amount_pct_of_total_supply': 0.10, # 10% of supply
582
- 'lock_duration': 2592000 # 30 days in seconds
583
- },
584
- {'event_type': 'HolderSnapshot', # NEW
585
- 'timestamp': 1729711320,
586
- 'relative_ts': 130,
587
- # This is a pointer to the pre-computed embedding
588
- # In a real system, this would be the index of the embedding
589
- 'holders': [ # Raw holder data
590
- {'wallet': 'addrW1', 'holding_pct': 0.15},
591
- {'wallet': 'addrW2', 'holding_pct': 0.05},
592
- # Add more mock holders if needed
593
- ]
594
- },
595
- {'event_type': 'OnChain_Snapshot', # NEW
596
- 'timestamp': 1729711320,
597
- 'relative_ts': 130,
598
- 'total_holders': 500,
599
- 'smart_traders': 25,
600
- 'kols': 3,
601
- 'holder_growth_rate': 0.15,
602
- 'top_10_holder_pct': 0.22,
603
- 'sniper_holding_pct': 0.05,
604
- 'rat_wallets_holding_pct': 0.02,
605
- 'bundle_holding_pct': 0.01,
606
- 'current_market_cap': 150000.0,
607
- 'volume': 50000.0,
608
- 'buy_count': 120,
609
- 'sell_count': 80,
610
- 'total_txns': 200,
611
- 'global_fees_paid': 1.5
612
- },
613
- {'event_type': 'TrendingToken', # NEW
614
- 'timestamp': 1729711330,
615
- 'relative_ts': 140,
616
- 'token_address': 'tknC', # The token that is trending
617
- 'list_source_id': vocab.TRENDING_LIST_SOURCE_TO_ID['Phantom'],
618
- 'timeframe_id': vocab.TRENDING_LIST_TIMEFRAME_TO_ID['1h'],
619
- 'rank': 3
620
- },
621
- {'event_type': 'BoostedToken', # NEW
622
- 'timestamp': 1729711340,
623
- 'relative_ts': 150,
624
- 'token_address': 'tknD', # The token that is boosted
625
- 'total_boost_amount': 5000.0,
626
- 'rank': 1
627
- },
628
- {'event_type': 'XQuoteTweet', # NEW
629
- 'timestamp': 1729711380,
630
- 'relative_ts': 190,
631
- 'wallet_address': 'addrW3', # The quoter
632
- 'quoter_text_emb_idx': pooler.get_idx('Wow, look at this! $TKA'),
633
- 'original_author_wallet_address': 'addrW1', # The original author
634
- 'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
635
- 'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
636
- },
637
- # --- NEW: Add special context tokens ---
638
- {'event_type': 'MIDDLE', 'timestamp': 1729711500, 'relative_ts': 195},
639
- {'event_type': 'PumpReply', # NEW
640
- 'timestamp': 1729711390,
641
- 'relative_ts': 200,
642
- 'wallet_address': 'addrW2', # The user who replied
643
- 'reply_text_emb_idx': pooler.get_idx('to the moon!')
644
- },
645
- {'event_type': 'DexBoost_Paid', # NEW
646
- 'timestamp': 1729711400,
647
- 'relative_ts': 210,
648
- 'amount': 5.0, # e.g., 5 Boost
649
- 'total_amount_on_token': 25.0 # 25 Boost Points
650
- },
651
- {'event_type': 'DexProfile_Updated', # NEW
652
- 'timestamp': 1729711410,
653
- 'relative_ts': 220,
654
- 'has_changed_website_flag': True,
655
- 'has_changed_twitter_flag': False,
656
- 'has_changed_telegram_flag': True,
657
- 'has_changed_description_flag': True,
658
- # Pre-computed text embeddings
659
- 'website_emb_idx': pooler.get_idx('new-token-website.com'),
660
- 'twitter_link_emb_idx': pooler.get_idx('old_handle'), # No change, so old link
661
- 'telegram_link_emb_idx': pooler.get_idx('new_tg_group'),
662
- 'description_emb_idx': pooler.get_idx('This is the new and improved token description.')
663
- },
664
- {'event_type': 'AlphaGroup_Call', # NEW
665
- 'timestamp': 1729711420,
666
- 'relative_ts': 230,
667
- 'group_id': vocab.ALPHA_GROUPS_TO_ID['Potion']
668
- },
669
- {'event_type': 'Channel_Call', # NEW
670
- 'timestamp': 1729711430,
671
- 'relative_ts': 240,
672
- 'channel_id': vocab.CALL_CHANNELS_TO_ID['MarcosCalls']
673
- },
674
- {'event_type': 'RECENT', 'timestamp': 1729711510, 'relative_ts': 245},
675
- {'event_type': 'CexListing', # NEW
676
- 'timestamp': 1729711440,
677
- 'relative_ts': 250,
678
- 'exchange_id': vocab.EXCHANGES_TO_ID['mexc']
679
- },
680
- {'event_type': 'TikTok_Trending_Hashtag', # NEW
681
- 'timestamp': 1729711450,
682
- 'relative_ts': 260,
683
- 'hashtag_name_emb_idx': pooler.get_idx('CryptoTok'),
684
- 'rank': 5
685
- },
686
- {'event_type': 'XTrending_Hashtag', # NEW
687
- 'timestamp': 1729711460,
688
- 'relative_ts': 270,
689
- 'hashtag_name_emb_idx': pooler.get_idx('SolanaMemes'),
690
- 'rank': 2
691
- },
692
- {'event_type': 'ChainSnapshot', # NEW
693
- 'timestamp': 1729711470,
694
- 'relative_ts': 280,
695
- 'native_token_price_usd': 150.75,
696
- 'gas_fee': 0.00015 # Example gas fee
697
- },
698
- {'event_type': 'Lighthouse_Snapshot', # NEW
699
- 'timestamp': 1729711480,
700
- 'relative_ts': 290,
701
- 'protocol_id': vocab.PROTOCOL_TO_ID['Pump V1'],
702
- 'timeframe_id': vocab.LIGHTHOUSE_TIMEFRAME_TO_ID['1h'],
703
- 'total_volume': 1.2e6,
704
- 'total_transactions': 5000,
705
- 'total_traders': 1200,
706
- 'total_tokens_created': 85,
707
- 'total_migrations': 70
708
- },
709
- {'event_type': 'Migrated', # NEW
710
- 'timestamp': 1729711490,
711
- 'relative_ts': 300,
712
- 'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM']
713
- },
714
-
715
- ],
716
- 'wallets': {
717
- 'addrW1': {'profile': profile1, 'socials': social1, 'holdings': holdings1},
718
- 'addrW2': {'profile': profile2, 'socials': social2, 'holdings': holdings2},
719
- # --- NEW: Add wallet 3 data ---
720
- 'addrW3': {
721
- 'profile': {**profile2, 'wallet_address': 'addrW3'}, # Reuse profile2 but change address
722
- 'socials': social3,
723
- 'holdings': []
724
- }
725
- },
726
- 'tokens': {
727
- 'tknA': tokenA_data, # Main token
728
- 'tknB': tokenB_data, # Quote token
729
- 'tknC': tokenC_data, # Trending token
730
- 'tknD': tokenD_data # Boosted token
731
- },
732
- # --- NEW: The pre-computed embedding pool is generated after collecting all items
733
- 'embedding_pooler': pooler, # Pass the pooler to generate the tensor later
734
-
735
- # --- NEW: Expanded graph_links to test all encoders ---
736
- # --- FIXED: Removed useless logging fields as per user request ---
737
- 'graph_links': {
738
- 'TransferLink': {'links': [{'timestamp': 1729711205}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
739
- 'BundleTradeLink': {'links': [{'timestamp': 1729711215}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
740
- 'CopiedTradeLink': {'links': [
741
- {'time_gap_on_buy_sec': 10, 'time_gap_on_sell_sec': 120, 'leader_pnl': 5.0, 'follower_pnl': 4.0, 'follower_buy_total': 100, 'follower_sell_total': 120}
742
- ], 'edges': [('addrW1', 'addrW2')]},
743
- 'CoordinatedActivityLink': {'links': [
744
- {'time_gap_on_first_sec': 5, 'time_gap_on_second_sec': 8}
745
- ], 'edges': [('addrW1', 'addrW2')]},
746
- 'MintedLink': {'links': [
747
- {'timestamp': 1729711200, 'buy_amount': 1e9}
748
- ], 'edges': [('addrW1', 'tknA')]},
749
- 'SnipedLink': {'links': [
750
- {'rank': 1, 'sniped_amount': 5e8}
751
- ], 'edges': [('addrW1', 'tknA')]},
752
- 'LockedSupplyLink': {'links': [
753
- {'amount': 1e10} # Only amount is needed
754
- ], 'edges': [('addrW1', 'tknA')]},
755
- 'BurnedLink': {'links': [
756
- {'timestamp': 1729711300} # Only timestamp is needed
757
- ], 'edges': [('addrW2', 'tknA')]},
758
- 'ProvidedLiquidityLink': {'links': [
759
- {'timestamp': 1729711250} # Only timestamp is needed
760
- ], 'edges': [('addrW1', 'tknA')]},
761
- 'WhaleOfLink': {'links': [
762
- {} # Just the existence of the link is the feature
763
- ], 'edges': [('addrW1', 'tknA')]},
764
- 'TopTraderOfLink': {'links': [
765
- {'pnl_at_creation': 50000.0} # Only PnL is needed
766
- ], 'edges': [('addrW2', 'tknA')]}
767
- },
768
 
769
- # --- FIXED: Removed chart_segments dictionary ---
770
- 'labels': torch.randn(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0),
771
- 'labels_mask': torch.ones(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0)
772
- }
773
-
774
- print("Mock raw batch created.")
775
-
776
- return item
 
1
+ # Apollo: Oracle Model
 
 
 
2
 
3
+ ## Project Status
4
+ **Phase:** Hyperparameter Optimization & Dataset Preparation.
 
 
 
 
 
 
5
 
6
+ ### Recent Updates (Jan 2026)
7
+ * **Hyperparameter Tuning**: Analyzed token trade distribution to determine optimal model parameters.
8
+ * **Max Sequence Length**: Set to **8192**. This covers >2 hours of high-frequency trading activity for high-volume tokens (verified against `HWVY...`) and the full lifecycle for 99% of tokens.
9
+ * **Prediction Horizons**: Set to **60s, 3m, 5m, 10m, 30m, 1h, 2h**.
10
+ * **Min Horizon (60s)**: Chosen to accommodate ~20s inference latency while capturing the "meat" of aggressive breakout movers.
11
+ * **Max Horizon (2h)**: Covers the timeframe where 99% of tokens hit their All-Time High.
12
+ * **Infrastructure**:
13
+ * Updated `train.sh` to use these new hyperparameters.
14
+ * Updated `scripts/cache_dataset.py` to ensure cached datasets are labeled with these horizons.
15
+ * Verified `DataFetcher` retrieves full trade histories (no hidden limits).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
+ ## Configuration Summary
 
18
 
19
+ | Parameter | Value | Rationale |
20
+ | :--- | :--- | :--- |
21
+ | **Max Seq Len** | `8192` | Captures >2h of intense pump activity or full rug lifecycle. |
22
+ | **Horizons** | `60, 180, 300, 600, 1800, 3600, 7200` | From "Scalp/Breakout" (1m) to "Runner/ATH" (2h). |
23
+ | **Inference Latency** | ~20s | Dictates the 60s minimum horizon. |
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ ## Usage
 
26
 
27
+ ### 1. Cache Dataset
28
+ Pre-process data into `.pt` files with correct labels.
29
+ ```bash
30
+ ./pre_cache.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  ```
32
 
33
+ ### 2. Train Model
34
+ Launch training with updated hyperparameters.
35
+ ```bash
36
+ ./train.sh
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ ## TODOs
40
+ * [ ] **Re-run Caching**: Since horizons changed, the existing cache (if any) is stale. Expected to run `pre_cache.sh`.
41
+ * [ ] **Verify Inference**: Ensure `inference.py` handles the 20s latency constraints gracefully (e.g. timestamp checks).
42
+ * [ ] **Model Architecture**: Confirm `8192` context length fits in VRAM with current model config (Attention implementation).
 
 
 
 
data/data_loader.py CHANGED
@@ -10,6 +10,7 @@ from typing import List, Dict, Any, Optional, Union, Tuple
10
  from pathlib import Path
11
  import numpy as np
12
  from bisect import bisect_left, bisect_right
 
13
 
14
  # We need the vocabulary for IDs and the processor for the pooler
15
  import models.vocabulary as vocab
@@ -136,12 +137,59 @@ class OracleDataset(Dataset):
136
  self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
137
  if not self.cached_files:
138
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
 
140
  self.num_samples = len(self.cached_files)
 
141
  if max_samples is not None:
142
  self.num_samples = min(max_samples, self.num_samples)
143
  self.cached_files = self.cached_files[:self.num_samples]
144
- print(f"INFO: Found {self.num_samples} cached samples to use.")
 
 
145
  self.sampled_mints = [] # Not needed in cached mode
146
  self.available_mints = []
147
 
@@ -201,6 +249,12 @@ class OracleDataset(Dataset):
201
  def __len__(self) -> int:
202
  return self.num_samples
203
 
 
 
 
 
 
 
204
  def _normalize_price_series(self, values: List[float]) -> List[float]:
205
  if not values:
206
  return values
@@ -874,26 +928,27 @@ class OracleDataset(Dataset):
874
  if not trade_ts_values:
875
  return None
876
 
877
- first_trade_ts = min(trade_ts_values)
878
- last_trade_ts = max(trade_ts_values)
879
- available_duration = last_trade_ts - mint_ts_value
880
- if available_duration <= 0:
881
- return None
882
- if available_duration < (min_window + min_label):
883
- return None
884
-
885
- required_horizon = preferred_horizon if available_duration >= (min_window + preferred_horizon) else min_label
886
- upper_bound = max(0.0, available_duration - required_horizon)
887
- lower_bound = max(min_window, int(max(0.0, first_trade_ts - mint_ts_value)))
888
 
889
- if upper_bound < lower_bound:
890
- return None
891
- if upper_bound == lower_bound:
892
- sample_offset = lower_bound
 
 
 
 
 
 
 
893
  else:
894
- sample_offset = random.randint(lower_bound, int(upper_bound))
 
895
 
896
- T_cutoff = mint_timestamp + datetime.timedelta(seconds=int(sample_offset))
897
 
898
  token_address = raw_data['token_address']
899
  creator_address = raw_data['creator_address']
@@ -1031,7 +1086,7 @@ class OracleDataset(Dataset):
1031
  max_horizon_seconds=self.max_cache_horizon_seconds,
1032
  include_wallet_data=False,
1033
  include_graph=False,
1034
- min_trades=50,
1035
  full_history=True, # Bypass H/B/H limits
1036
  prune_failed=True, # Drop failed trades
1037
  prune_transfers=True # Drop transfers (captured in snapshots)
@@ -1436,20 +1491,76 @@ class OracleDataset(Dataset):
1436
  event_sequence = [entry[1] for entry in event_sequence_entries]
1437
 
1438
  # 8. Compute Labels using future data
1439
- labels = torch.zeros(0)
1440
- labels_mask = torch.zeros(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1441
 
1442
- # NEED TO IMPORT OR REFIND future_trades_for_labels LOGIC
1443
- # We need logic to compute future returns
1444
- # For now, placeholder or port the logic
1445
 
1446
- # 9. Return Item
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1447
  return {
1448
  'event_sequence': event_sequence,
1449
  'wallets': wallet_data,
1450
  'tokens': all_token_data,
1451
  'graph_links': graph_links,
1452
  'embedding_pooler': pooler,
1453
- 'labels': labels,
1454
- 'labels_mask': labels_mask
1455
  }
 
10
  from pathlib import Path
11
  import numpy as np
12
  from bisect import bisect_left, bisect_right
13
+ import json
14
 
15
  # We need the vocabulary for IDs and the processor for the pooler
16
  import models.vocabulary as vocab
 
137
  self.cached_files = sorted(self.cache_dir.glob("sample_*.pt"), key=lambda p: int(p.stem.split('_')[1]))
138
  if not self.cached_files:
139
  raise RuntimeError(f"Cache directory '{self.cache_dir}' provided but contains no 'sample_*.pt' files.")
140
+
141
+ # --- NEW: Strict Metadata & Weighting ---
142
+ metadata_path = self.cache_dir / "metadata.jsonl"
143
+ if not metadata_path.exists():
144
+ raise RuntimeError(f"FATAL: metadata.jsonl not found in {self.cache_dir}. Cannot train without class-balanced sampling.")
145
+
146
+ print(f"INFO: Loading metadata from {metadata_path}...")
147
+ file_class_map = {}
148
+ class_counts = defaultdict(int)
149
+
150
+ with open(metadata_path, 'r') as f:
151
+ for line in f:
152
+ try:
153
+ entry = json.loads(line)
154
+ fname = entry['file']
155
+ cid = entry['class_id']
156
+ file_class_map[fname] = cid
157
+ class_counts[cid] += 1
158
+ except Exception as e:
159
+ print(f"WARN: Failed to parse metadata line: {e}")
160
+
161
+ print(f"INFO: Class Distribution: {dict(class_counts)}")
162
+
163
+ # Compute Weights
164
+ self.weights_list = []
165
+ valid_files = []
166
+
167
+ # We iterate properly sorted cached files to align with __getitem__ index
168
+ for p in self.cached_files:
169
+ fname = p.name
170
+ if fname not in file_class_map:
171
+ # Should be fatal if strict, but maybe some files were skipped?
172
+ # If file exists but no metadata, we can't weight it properly.
173
+ # Current pipeline writes metadata only for successful caches.
174
+ # So if it's in cached_files but not metadata, it might be a stale file.
175
+ print(f"WARN: File {fname} found in cache but missing metadata. Skipping.")
176
+ continue
177
+
178
+ cid = file_class_map[fname]
179
+ count = class_counts[cid]
180
+ weight = 1.0 / count if count > 0 else 0.0
181
+ self.weights_list.append(weight)
182
+ valid_files.append(p)
183
 
184
+ self.cached_files = valid_files
185
  self.num_samples = len(self.cached_files)
186
+
187
  if max_samples is not None:
188
  self.num_samples = min(max_samples, self.num_samples)
189
  self.cached_files = self.cached_files[:self.num_samples]
190
+ self.weights_list = self.weights_list[:self.num_samples]
191
+
192
+ print(f"INFO: Weighted Dataset Ready. {self.num_samples} samples.")
193
  self.sampled_mints = [] # Not needed in cached mode
194
  self.available_mints = []
195
 
 
249
  def __len__(self) -> int:
250
  return self.num_samples
251
 
252
+ def get_weights(self) -> torch.DoubleTensor:
253
+ """Returns the sampling weights for the dataset."""
254
+ if hasattr(self, 'weights_list') and self.weights_list:
255
+ return torch.as_tensor(self.weights_list, dtype=torch.double)
256
+ return None
257
+
258
  def _normalize_price_series(self, values: List[float]) -> List[float]:
259
  if not values:
260
  return values
 
928
  if not trade_ts_values:
929
  return None
930
 
931
+ # Cache guarantees min_trades=25, so we proceed assuming valid data.
932
+ # But for safety in dynamic sampling:
933
+ if not trade_ts_values:
934
+ return None
 
 
 
 
 
 
 
935
 
936
+ # Sort trades to find the 24th trade timestamp
937
+ sorted_trades_ts = sorted(trade_ts_values)
938
+
939
+ # T_start = Timestamp of the 25th trade (index 24)
940
+ # If somehow we have fewer than 25 trades (cache mismatch?), fallback to last.
941
+ safe_idx = min(24, len(sorted_trades_ts) - 1)
942
+ min_cutoff_ts = sorted_trades_ts[safe_idx]
943
+ max_cutoff_ts = sorted_trades_ts[-1]
944
+
945
+ if max_cutoff_ts <= min_cutoff_ts:
946
+ sample_offset_ts = min_cutoff_ts
947
  else:
948
+ # Standard case: sample uniformly between [Trade[24], LastTrade]
949
+ sample_offset_ts = random.uniform(min_cutoff_ts, max_cutoff_ts)
950
 
951
+ T_cutoff = datetime.datetime.fromtimestamp(sample_offset_ts, tz=datetime.timezone.utc)
952
 
953
  token_address = raw_data['token_address']
954
  creator_address = raw_data['creator_address']
 
1086
  max_horizon_seconds=self.max_cache_horizon_seconds,
1087
  include_wallet_data=False,
1088
  include_graph=False,
1089
+ min_trades=25,
1090
  full_history=True, # Bypass H/B/H limits
1091
  prune_failed=True, # Drop failed trades
1092
  prune_transfers=True # Drop transfers (captured in snapshots)
 
1491
  event_sequence = [entry[1] for entry in event_sequence_entries]
1492
 
1493
  # 8. Compute Labels using future data
1494
+ # Define horizons (e.g., [60, 120, ...])
1495
+ horizons = sorted(self.horizons_seconds)
1496
+
1497
+ # Pre-sort future trades for efficient searching
1498
+ # Note: future_trades_for_labels contains ALL trades (past & future relative to T_cutoff)
1499
+ # We need to find the price at T_cutoff and at T_cutoff + h
1500
+
1501
+ all_trades = future_trades_for_labels
1502
+ # Ensure sorted
1503
+ all_trades.sort(key=lambda x: _timestamp_to_order_value(x['timestamp']))
1504
+
1505
+ # Find price at T_cutoff (Current Price)
1506
+ # It's the last trade before or at T_cutoff
1507
+ current_price = 0.0
1508
+ cutoff_ts_val = T_cutoff.timestamp()
1509
+ last_trade_ts_val = _timestamp_to_order_value(all_trades[-1]['timestamp'])
1510
+
1511
+ # Find index of last trade <= T_cutoff
1512
+ # We can use binary search or simple iteration since we are building dataset
1513
+ # Iterating is safer for complex logic
1514
+ current_price_idx = -1
1515
+ for i, t in enumerate(all_trades):
1516
+ if _timestamp_to_order_value(t['timestamp']) <= cutoff_ts_val:
1517
+ current_price = float(t['price_usd'])
1518
+ current_price_idx = i
1519
+ else:
1520
+ break
1521
 
1522
+ label_values = []
1523
+ mask_values = []
 
1524
 
1525
+ for h in horizons:
1526
+ target_ts = cutoff_ts_val + h
1527
+
1528
+ if target_ts > last_trade_ts_val:
1529
+ # Horizon extends beyond known history
1530
+ # We MASK this label. We do NOT guess 0.
1531
+ label_values.append(0.0) # Dummy value
1532
+ mask_values.append(0.0) # Mask = 0 (Ignore)
1533
+ else:
1534
+ # Find price at target_ts
1535
+ # It is the last trade strictly before or at target_ts
1536
+ future_price = current_price # Default to current if no trades found in window? Unlikely if checked range.
1537
+
1538
+ # Check trades between current_idx and target
1539
+ # Optimization: start search from current_price_idx
1540
+ found_future = False
1541
+ for j in range(current_price_idx, len(all_trades)):
1542
+ t = all_trades[j]
1543
+ t_ts = _timestamp_to_order_value(t['timestamp'])
1544
+ if t_ts <= target_ts:
1545
+ future_price = float(t['price_usd'])
1546
+ found_future = True
1547
+ else:
1548
+ break # Optimization: surpassed target_ts
1549
+
1550
+ if current_price > 0:
1551
+ ret = (future_price - current_price) / current_price
1552
+ else:
1553
+ ret = 0.0
1554
+
1555
+ label_values.append(ret)
1556
+ mask_values.append(1.0) # Mask = 1 (Valid)
1557
+
1558
  return {
1559
  'event_sequence': event_sequence,
1560
  'wallets': wallet_data,
1561
  'tokens': all_token_data,
1562
  'graph_links': graph_links,
1563
  'embedding_pooler': pooler,
1564
+ 'labels': torch.tensor(label_values, dtype=torch.float32),
1565
+ 'labels_mask': torch.tensor(mask_values, dtype=torch.float32)
1566
  }
events.md ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =========================================
2
+ # Entity Encoders
3
+ # =========================================
4
+ # These are generated offline/streaming and are the "vocabulary" for the model.
5
+
6
+ <WalletEmbedding> # Embedding of a wallet's relationships, behavior, and history.
7
+ <WalletEmbedding> = [
8
+ // Data from the 'wallet_profiles' table (Wallet-level lifetime and daily/weekly stats)
9
+ wallet_profiles_row: [
10
+ // Core Info & Timestamps
11
+ age, // No Contextual
12
+ wallet_address, // Primary wallet identifier
13
+
14
+
15
+ // 7. NEW: Deployed Token Aggregates (8 Features)
16
+ deployed_tokens_count, // Total tokens created
17
+ deployed_tokens_migrated_pct, // % that migrated
18
+ deployed_tokens_avg_lifetime_sec, // Avg duration before dev selling
19
+ deployed_tokens_avg_peak_mc_usd, // Avg peak marketcap
20
+ deployed_tokens_median_peak_mc_usd,
21
+
22
+ // Metadata & Balances
23
+ balance, // Current SOL balance
24
+
25
+ // Lifetime Transaction Counts (Total history)
26
+ transfers_in_count, // Total native transfers received
27
+ transfers_out_count, // Total native transfers sent
28
+ spl_transfers_in_count, // Total SPL token transfers received
29
+ spl_transfers_out_count,// Total SPL token transfers sent
30
+
31
+ // Lifetime Trading Stats (Total history)
32
+ total_buys_count, // Total buys across all tokens
33
+ total_sells_count, // Total sells across all tokens
34
+ total_winrate, // Overall trading winrate
35
+
36
+ // 1-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
37
+ stats_1d_realized_profit_sol,
38
+ stats_1d_realized_profit_pnl,
39
+ stats_1d_buy_count,
40
+ stats_1d_sell_count,
41
+ stats_1d_transfer_in_count,
42
+ stats_1d_transfer_out_count,
43
+ stats_1d_avg_holding_period,
44
+ stats_1d_total_bought_cost_sol,
45
+ stats_1d_total_sold_income_sol,
46
+ stats_1d_total_fee,
47
+ stats_1d_winrate,
48
+ stats_1d_tokens_traded,
49
+
50
+ // 7-Day Stats (Realized P&L, Counts, Averages, Volume, Fees, Winrate)
51
+ stats_7d_realized_profit_sol,
52
+ stats_7d_realized_profit_pnl,
53
+ stats_7d_buy_count,
54
+ stats_7d_sell_count,
55
+ stats_7d_transfer_in_count,
56
+ stats_7d_transfer_out_count,
57
+ stats_7d_avg_holding_period,
58
+ stats_7d_total_bought_cost_sol,
59
+ stats_7d_total_sold_income_sol,
60
+ stats_7d_total_fee,
61
+ stats_7d_winrate,
62
+ stats_7d_tokens_traded,
63
+
64
+ // 30 Days is to useless in the context
65
+ ],
66
+
67
+ // Data from the 'wallet_socials' table (Social media and profile info)
68
+ wallet_socials_row: [
69
+ has_pf_profile,
70
+ has_twitter,
71
+ has_telegram,
72
+ is_exchange_wallet,
73
+ username,
74
+ ],
75
+ // Data from the 'wallet_holdings' table (Token-level statistics for held tokens)
76
+ wallet_holdings_pool: [
77
+ <TokenVibeEmbedding>,
78
+ holding_time, // How much he held the token (We check only tokens that currently is holding, or recently traded)
79
+
80
+ balance_pct_to_supply, // Current quantity of the token held
81
+
82
+ // History (Amounts & Costs)
83
+ history_bought_amount_sol, // Total amount of token bought
84
+ bought_amount_sol_pct_to_native_balance // Is he traded a lot of his wallet size
85
+
86
+ // History (Counts)
87
+ history_total_buys, // Total number of buy transactions
88
+ history_total_sells, // Total number of sell transactions
89
+
90
+ // Profit and Loss
91
+ realized_profit_pnl, // Realized P&L as a percentage
92
+ realized_profit_sol,
93
+
94
+ // Transfers (Non-trade movements)
95
+ history_transfer_in,
96
+ history_transfer_out,
97
+
98
+ avarage_trade_gap_seconds,
99
+ total_priority_fees, // Total tips + Priority Fees
100
+ ]
101
+ ]
102
+
103
+ <TokenVibeEmbedding> # Multimodal embedding of a token's identity
104
+ <TokenVibeEmbedding> = [<TokenAddressEmbedding>, <NameEmbedding>, <SymbolEmbedding>, <ImageEmbedding>, protocol_id]
105
+
106
+ <TextEmbedding> # Text embedding MultiModal processor.
107
+ <MediaEmbedding> # Multimodal VIT encoder.
108
+
109
+ # -----------------------------------------
110
+ # 1. TradeEncoder
111
+ # -----------------------------------------
112
+
113
+ # Captures large-size trades from any wallet.
114
+ [timestamp, 'LargeTrade', relative_ts, <WalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
115
+
116
+ # Captures the high-signal "Dev Sold or Bought" event.
117
+ [timestamp, 'Deployer_Trade', relative_ts, <CreatorWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
118
+
119
+ # Captures *all* trades from pre-defined high-P&L/win-rate, kol and known wallets.
120
+ [timestamp, 'SmartWallet_Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
121
+
122
+ # Raw trades. Loaded in H/B/H Prefix (first ~10k) and Suffix (last ~5k).
123
+ [timestamp, 'Trade', relative_ts, <TraderWalletEmbedding>, trade_direction, sol_amount, dex_platform_id, priority_fee, mev_protection, token_amount_pct_of_holding, quote_amount_pct_of_holding, slippage, price_impact, success, is_bundle, total_usd]
124
+
125
+ # -----------------------------------------
126
+ # 2. TransferEncoder
127
+ # -----------------------------------------
128
+
129
+ # Raw transfers. Loaded in H/B/H Prefix (all in first ~10k trade window) and Suffix (all in last ~5k trade window).
130
+ [timestamp, 'Transfer', relative_ts, <SourceWalletEmbedding>, <DestinationWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
131
+
132
+ # Captures scarce, large transfers *after* the initial launch window.
133
+ [timestamp, 'LargeTransfer', relative_ts, <FromWalletEmbedding>, <ToWalletEmbedding>, token_amount, transfer_pct_of_total_supply, transfer_pct_of_holding, priority_fee]
134
+
135
+ # -----------------------------------------
136
+ # 3. LifecycleEncoder
137
+ # -----------------------------------------
138
+
139
+ # The T0 event.
140
+ [timestamp, 'Mint', 0, <CreatorWalletEmbedding>, <TokenVibeEmbedding>]
141
+
142
+ # -----------------------------------------
143
+ # 3. PoolEncoder
144
+ # -----------------------------------------
145
+
146
+ # Signals migration from launchpad to a real pool.
147
+ [timestamp, 'PoolCreated', relative_ts, <ProviderWalletEmbedding>, protocol_id, <QuoteTokenVibeEmbedding>, base_amount, quote_amount, quote_pct_to_main_pool_balance, base_pct_to_main_pool_balance]
148
+
149
+ # Signals LP addition or removal.
150
+ [timestamp, 'LiquidityChange', relative_ts, <ProviderWalletEmbedding>, <QuoteTokenVibeEmbedding>, change_type_id, quote_amount, quote_pct_to_current_pool_balance]
151
+
152
+ # Signals creator/dev taking platform fees.
153
+ [timestamp, 'FeeCollected', relative_ts, <RecipientWalletEmbedding>, sol_amount, token_amount]
154
+
155
+
156
+ # -----------------------------------------
157
+ # SupplyEncoder
158
+ # -----------------------------------------
159
+
160
+ # Signals a supply reduction.
161
+ [timestamp, 'TokenBurn', relative_ts, <BurnerWalletEmbedding>, amount_pct_of_total_supply, amount_tokens_burned]
162
+
163
+ # Signals locked supply, e.g., for team/marketing.
164
+ [timestamp, 'SupplyLock', relative_ts, <LockerWalletEmbedding>, amount_pct_of_total_supply, lock_duration]
165
+
166
+ # -----------------------------------------
167
+ # ChartEncoder
168
+ # -----------------------------------------
169
+
170
+ # (The "Sliding Window") This is the new chart event.
171
+ [timestamp, 'Chart_Segment', relative_ts, OHLC_segment, chart_interval_id]
172
+
173
+ # -----------------------------------------
174
+ # PulseEncoder
175
+ # -----------------------------------------
176
+
177
+ # It is a low-frequency event (Dynamic Interval: 5min, 15min, or 1hr based on token age).
178
+ [timestamp, 'OnChain_Snapshot', relative_ts, total_holders, smart_traders, kols, holder_growth_rate, top_10_holder_pct, sniper_holding_pct, rat_wallets_holding_pct, bundle_holding_pct, current_market_cap, liquidity, volume, buy_count, sell_count, total_txns, global_fees_paid]
179
+
180
+ # -----------------------------------------
181
+ # HoldersListEncoder
182
+ # -----------------------------------------
183
+
184
+ <HolderDistributionEmbedding> # Transformer-based embedding of the top holders (WalletEmbeddings + Pct).
185
+
186
+ # Token-specific holder analysis.
187
+ [timestamp, 'HolderSnapshot', relative_ts, <HolderDistributionEmbedding>]
188
+
189
+
190
+ # -----------------------------------------
191
+ # ChainSnapshotEncoder
192
+ # -----------------------------------------
193
+
194
+ # Broad chain-level market conditions.
195
+ [timestamp, 'ChainSnapshot', relative_ts, native_token_price_usd, gas_fee]
196
+
197
+ # Launchpad market regime (using absolute, log-normalized values).
198
+ [timestamp, 'Lighthouse_Snapshot', relative_ts, protocol_id, timeframe_id, total_volume, total_transactions, total_traders, total_tokens_created, total_migrations]
199
+
200
+ # -----------------------------------------
201
+ # TokenTrendingListEncoder
202
+ # -----------------------------------------
203
+
204
+ # Fires *per token* on a trending list. The high-attention "meta" signal.
205
+ [timestamp, 'TrendingToken', relative_ts, <TokenVibeEmbedding_of_trending_token>, list_source_id, timeframe_id, rank]
206
+
207
+ # Fires *per token* on the boosted list.
208
+ [timestamp, 'BoostedToken', relative_ts, <TokenVibeEmbedding_of_boosted_token>, total_boost_amount, rank]
209
+
210
+ # -----------------------------------------
211
+ # LaunchpadTheadEncoder
212
+ # -----------------------------------------
213
+
214
+ # On-platform social signal (Pump.fun comments).
215
+ [timestamp, 'PumpReply', relative_ts, <UserWalletEmbedding>, <ReplyTextEmbedding>]
216
+
217
+ # -----------------------------------------
218
+ # CTEncoder
219
+ # -----------------------------------------
220
+
221
+ # Off-platform social signal (Twitter).
222
+ [timestamp, 'XPost', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>]
223
+ [timestamp, 'XRetweet', relative_ts, <RetweeterWalletEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
224
+ [timestamp, 'XReply', relative_ts, <AuthorWalletEmbedding>, <PostTextEmbedding>, <MediaEmbedding>, <MainTweetEmbedding>]
225
+ [timestamp, 'XQuoteTweet', relative_ts, <QuoterWalletEmbedding>, <QuoterTextEmbedding>, <OriginalAuthorWalletEmbedding>, <OriginalPostTextEmbedding>, <OriginalPostMediaEmbedding>]
226
+
227
+ # -----------------------------------------
228
+ # GlobalTrendingEncoder
229
+ # -----------------------------------------
230
+
231
+ # Broader cultural trend signal (TikTok).
232
+ [timestamp, 'TikTok_Trending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
233
+
234
+ # Broader cultural trend signal (Twitter).
235
+ [timestamp, 'XTrending_Hashtag', relative_ts, <HashtagNameEmbedding>, rank]
236
+
237
+ # -----------------------------------------
238
+ # TrackerEncoder
239
+ # -----------------------------------------
240
+
241
+ # Retail marketing signal (Paid groups).
242
+ [timestamp, 'AlphaGroup_Call', relative_ts, group_id]
243
+
244
+ [timestamp, 'Call_Channel', relative_ts, channel_id]
245
+
246
+ # High-impact catalyst event.
247
+ [timestamp, 'CexListing', relative_ts, exchange_id]
248
+
249
+ # High-impact catalyst event.
250
+ [timestamp, 'Migrated', relative_ts, protocol_id]
251
+
252
+ # -----------------------------------------
253
+ # Dex Encoder
254
+ # -----------------------------------------
255
+
256
+ [timestamp, 'DexBoost_Paid', relative_ts, amount, total_amount_on_token]
257
+
258
+ [timestamp, 'DexProfile_Updated', relative_ts, has_changed_website_flag, has_changed_twitter_flag, has_changed_telegram_flag, has_changed_description_flag, <WebsiteEmbedding>, <TwitterLinkEmbedding>, <NewDescriptionEmbeeded>]
259
+
260
+ ### **Global Context Injection**
261
+
262
+ <PRELAUNCH> <LAUNCH> <Middle> <RECENT>
263
+
264
+ ### **Token Role Embedding**
265
+
266
+ <TokenVibeEmbedding_of_Token_A> + Subject_Token_Role
267
+
268
+ <TokenVibeEmbedding_of_Token_B> + Trending_Token_Role
269
+
270
+ <QuoteTokenVibeEmbedding_of_USDC> + Quote_Token_Role
271
+
272
+
273
+ # **Links**
274
+
275
+ ### `TransferLink`
276
+
277
+ ```
278
+ ['signature', 'source', 'destination', 'mint', 'timestamp']
279
+ ```
280
+
281
+ -----
282
+
283
+ ### `BundleTradeLink`
284
+
285
+ ```
286
+ ['signatures', 'wallet_a', 'wallet_b', 'mint', 'slot', 'timestamp']
287
+ ```
288
+
289
+ -----
290
+
291
+ ### `CopiedTradeLink`
292
+
293
+ ```
294
+ ['leader_buy_sig', 'leader_sell_sig', 'follower_buy_sig', 'follower_sell_sig', 'follower', 'leader', 'mint', 'time_gap_on_buy_sec', 'time_gap_on_sell_sec', 'leader_pnl', 'follower_pnl', 'leader_buy_total', 'leader_sell_total', 'follower_buy_total', 'follower_sell_total', 'follower_buy_slippage', 'follower_sell_slippage']
295
+ ```
296
+
297
+ -----
298
+
299
+ ### `CoordinatedActivityLink`
300
+
301
+ ```
302
+ ['leader_first_sig', 'leader_second_sig', 'follower_first_sig', 'follower_second_sig', 'follower', 'leader', 'mint', 'time_gap_on_first_sec', 'time_gap_on_second_sec']
303
+ ```
304
+
305
+ -----
306
+
307
+ ### `MintedLink`
308
+
309
+ ```
310
+ ['signature', 'timestamp', 'buy_amount']
311
+ ```
312
+
313
+ -----
314
+
315
+ ### `SnipedLink`
316
+
317
+ ```
318
+ ['signature', 'rank', 'sniped_amount']
319
+ ```
320
+
321
+ -----
322
+
323
+ ### `LockedSupplyLink`
324
+
325
+ ```
326
+ ['signature', 'amount', 'unlock_timestamp']
327
+ ```
328
+
329
+ -----
330
+
331
+ ### `BurnedLink`
332
+
333
+ ```
334
+ ['signature', 'amount', 'timestamp']
335
+ ```
336
+
337
+ -----
338
+
339
+ ### `ProvidedLiquidityLink`
340
+
341
+ ```
342
+ ['signature', 'wallet', 'token', 'pool_address', 'amount_base', 'amount_quote', 'timestamp']
343
+ ```
344
+
345
+ -----
346
+
347
+ ### `WhaleOfLink`
348
+
349
+ ```
350
+ ['wallet', 'token', 'holding_pct_at_creation', 'ath_usd_at_creation']
351
+ ```
352
+
353
+ -----
354
+
355
+ ### `TopTraderOfLink`
356
+
357
+ ```
358
+ ['wallet', 'token', 'pnl_at_creation', 'ath_usd_at_creation']
359
+ ```
360
+
361
+
362
+
363
+
364
+ /////
365
+
366
+ def __gettestitem__(self, idx: int) -> Dict[str, Any]:
367
+ """
368
+ Generates a single complex data item, structured for the MemecoinCollator.
369
+ NOTE: This currently returns the same mock data regardless of `idx`.
370
+ """
371
+ # --- 1. Setup Pooler and Define Raw Data ---
372
+ pooler = EmbeddingPooler()
373
+
374
+ # --- 5. Create Mock Raw Batch Data (FIXED) ---
375
+ print("Creating mock raw batch...")
376
+
377
+ # (Wallet profiles, socials, holdings definitions are unchanged)
378
+ profile1 = {
379
+ 'wallet_address': 'addrW1', 'age': 1.5e7, 'balance': 10.5,
380
+ 'deployed_tokens_count': 2, 'deployed_tokens_migrated_pct': 0.5, 'deployed_tokens_avg_lifetime_sec': 36000.0, 'deployed_tokens_avg_peak_mc_usd': 100000.0, 'deployed_tokens_median_peak_mc_usd': 50000.0,
381
+ 'transfers_in_count': 10, 'transfers_out_count': 5, 'spl_transfers_in_count': 20, 'spl_transfers_out_count': 15,
382
+ 'total_buys_count': 50, 'total_sells_count': 40, 'total_winrate': 0.6,
383
+ 'stats_1d_realized_profit_sol': 1.2, 'stats_1d_realized_profit_pnl': 0.1, 'stats_1d_buy_count': 5, 'stats_1d_sell_count': 3, 'stats_1d_transfer_in_count': 2, 'stats_1d_transfer_out_count': 1, 'stats_1d_avg_holding_period': 3600, 'stats_1d_total_bought_cost_sol': 10.0, 'stats_1d_total_sold_income_sol': 11.2, 'stats_1d_total_fee': 0.1, 'stats_1d_winrate': 0.7, 'stats_1d_tokens_traded': 4,
384
+ 'stats_7d_realized_profit_sol': 5.0, 'stats_7d_realized_profit_pnl': 0.2, 'stats_7d_buy_count': 20, 'stats_7d_sell_count': 15, 'stats_7d_transfer_in_count': 8, 'stats_7d_transfer_out_count': 4, 'stats_7d_avg_holding_period': 7200, 'stats_7d_total_bought_cost_sol': 40.0, 'stats_7d_total_sold_income_sol': 45.0, 'stats_7d_total_fee': 0.5, 'stats_7d_winrate': 0.65, 'stats_7d_tokens_traded': 10,
385
+ }
386
+ social1 = {'has_pf_profile': True, 'has_twitter': True, 'has_telegram': False, 'is_exchange_wallet': False, 'username': 'trader_one'}
387
+ holdings1 = [
388
+ {'mint_address': 'tknA', 'holding_time': 3600.0, 'realized_profit_sol': 5.2, 'total_priority_fees': 0.05, 'balance_pct_to_supply': 0.01, 'history_bought_amount_sol': 10, 'bought_amount_sol_pct_to_native_balance': 0.5, 'history_total_buys': 5, 'history_total_sells': 2, 'realized_profit_pnl': 0.52, 'history_transfer_in': 1, 'history_transfer_out': 0, 'avarage_trade_gap_seconds': 300},
389
+ ]
390
+ profile2 = {
391
+ 'wallet_address': 'addrW2', 'age': 1e6, 'balance': 1.0,
392
+ 'deployed_tokens_count': 0, 'deployed_tokens_migrated_pct': 0.0, 'deployed_tokens_avg_lifetime_sec': 0.0, 'deployed_tokens_avg_peak_mc_usd': 0.0, 'deployed_tokens_median_peak_mc_usd': 0.0,
393
+ 'transfers_in_count': 1, 'transfers_out_count': 0, 'spl_transfers_in_count': 0, 'spl_transfers_out_count': 0,
394
+ 'total_buys_count': 0, 'total_sells_count': 0, 'total_winrate': 0.0,
395
+ 'stats_1d_realized_profit_sol': 0.0, 'stats_1d_realized_profit_pnl': 0.0, 'stats_1d_buy_count': 0, 'stats_1d_sell_count': 0, 'stats_1d_transfer_in_count': 0, 'stats_1d_transfer_out_count': 0, 'stats_1d_avg_holding_period': 0, 'stats_1d_total_bought_cost_sol': 0.0, 'stats_1d_total_sold_income_sol': 0.0, 'stats_1d_total_fee': 0.0, 'stats_1d_winrate': 0.0, 'stats_1d_tokens_traded': 0,
396
+ 'stats_7d_realized_profit_sol': 0.0, 'stats_7d_realized_profit_pnl': 0.0, 'stats_7d_buy_count': 0, 'stats_7d_sell_count': 0, 'stats_7d_transfer_in_count': 0, 'stats_7d_transfer_out_count': 0, 'stats_7d_avg_holding_period': 0, 'stats_7d_total_bought_cost_sol': 0.0, 'stats_7d_total_sold_income_sol': 0.0, 'stats_7d_total_fee': 0.0, 'stats_7d_winrate': 0.0, 'stats_7d_tokens_traded': 0,
397
+ }
398
+ social2 = {'has_pf_profile': False, 'has_twitter': False, 'has_telegram': False, 'is_exchange_wallet': True, 'username': 'cex_wallet'}
399
+ holdings2 = []
400
+
401
+
402
+ # Define raw data and get their indices
403
+ tokenA_data = {
404
+ 'address_emb_idx': pooler.get_idx('tknA'),
405
+ 'name_emb_idx': pooler.get_idx('Token A'),
406
+ 'symbol_emb_idx': pooler.get_idx('TKA'),
407
+ 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
408
+ 'protocol': 1
409
+ }
410
+ # Add wallet usernames to the pool
411
+ wallet1_user_idx = pooler.get_idx(social1['username'])
412
+ wallet2_user_idx = pooler.get_idx(social2['username'])
413
+ social1['username_emb_idx'] = wallet1_user_idx
414
+ social2['username_emb_idx'] = wallet2_user_idx
415
+ # --- NEW: Add a third wallet for social tests ---
416
+ social3 = {'has_pf_profile': False, 'has_twitter': True, 'has_telegram': True, 'is_exchange_wallet': False, 'username': 'social_butterfly'}
417
+ wallet3_user_idx = pooler.get_idx(social3['username'])
418
+ social3['username_emb_idx'] = wallet3_user_idx
419
+
420
+ # Create the final pre-computed data structures
421
+ tokenB_data = {
422
+ 'address_emb_idx': pooler.get_idx('tknA'),
423
+ 'name_emb_idx': pooler.get_idx('Token A'),
424
+ 'symbol_emb_idx': pooler.get_idx('TKA'),
425
+ 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
426
+ 'protocol': 1
427
+ }
428
+
429
+ tokenC_data = {
430
+ 'address_emb_idx': pooler.get_idx('tknA'),
431
+ 'name_emb_idx': pooler.get_idx('Token A'),
432
+ 'symbol_emb_idx': pooler.get_idx('TKA'),
433
+ 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
434
+ 'protocol': 1
435
+ }
436
+
437
+ tokenD_data = {
438
+ 'address_emb_idx': pooler.get_idx('tknA'),
439
+ 'name_emb_idx': pooler.get_idx('Token A'),
440
+ 'symbol_emb_idx': pooler.get_idx('TKA'),
441
+ 'image_emb_idx': pooler.get_idx(Image.new('RGB',(256,256), color='blue')),
442
+ 'protocol': 1
443
+ }
444
+
445
+ item = {
446
+ 'event_sequence': [
447
+ {'event_type': 'XPost', # NEW
448
+ 'timestamp': 1729711350,
449
+ 'relative_ts': -25,
450
+ 'wallet_address': 'addrW1', # Author
451
+ 'text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
452
+ 'media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
453
+ },
454
+ {'event_type': 'XReply', # NEW
455
+ 'timestamp': 1729711360,
456
+ 'relative_ts': -35,
457
+ 'wallet_address': 'addrW2', # Replier
458
+ 'text_emb_idx': pooler.get_idx('This is a reply to the main tweet'),
459
+ 'media_emb_idx': pooler.get_idx(None), # No media in reply
460
+ 'main_tweet_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA')
461
+ },
462
+ {'event_type': 'XRetweet', # NEW
463
+ 'timestamp': 1729711370,
464
+ 'relative_ts': -40,
465
+ 'wallet_address': 'addrW3', # The retweeter
466
+ 'original_author_wallet_address': 'addrW1', # The original author
467
+ 'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
468
+ 'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
469
+ },
470
+ # --- CORRECTED: Test a pre-launch event with negative relative_ts ---
471
+ {'event_type': 'Transfer',
472
+ 'timestamp': 1729711180,
473
+ 'relative_ts': -10, # Negative relative_ts indicates pre-launch
474
+ 'wallet_address': 'addrW2',
475
+ 'destination_wallet_address': 'addrW1',
476
+ 'token_address': 'tknA',
477
+ 'token_amount': 1000.0, 'transfer_pct_of_total_supply': 0.0, 'transfer_pct_of_holding': 0.0, 'priority_fee': 0.0
478
+ },
479
+ {'event_type': 'Mint', 'timestamp': 1729711190, 'relative_ts': 0, 'wallet_address': 'addrW1', 'token_address': 'tknA'},
480
+ {'event_type': 'Chart_Segment', 'timestamp': 1729711200, 'relative_ts': 60, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'}, # This is high-def (segment 0) by default
481
+ {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 120, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'}, # You can mark this as blurry
482
+ {'event_type': 'Transfer',
483
+ 'timestamp': 1729711210,
484
+ 'relative_ts': 20,
485
+ 'wallet_address': 'addrW1', # Source
486
+ 'destination_wallet_address': 'addrW2', # Destination
487
+ 'token_address': 'tknA', # Need token for context? (Optional, depends on design)
488
+ 'token_amount': 500.0,
489
+ 'transfer_pct_of_total_supply': 0.005,
490
+ 'transfer_pct_of_holding': 0.1,
491
+ 'priority_fee': 0.0001
492
+ },
493
+ {'event_type': 'Trade',
494
+ 'timestamp': 1729711220,
495
+ 'relative_ts': 30,
496
+ 'wallet_address': 'addrW1',
497
+ 'token_address': 'tknA',
498
+ 'trade_direction': 0,
499
+ 'sol_amount': 0.5,
500
+ # --- FIXED: Pass the integer ID directly ---
501
+ 'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
502
+ 'priority_fee': 0.0002,
503
+ 'mev_protection': False,
504
+ 'token_amount_pct_of_holding': 0.05, 'quote_amount_pct_of_holding': 0.02,
505
+ 'slippage': 0.01, 'price_impact': 0.005, 'success': True, 'is_bundle': False, 'total_usd': 75.0
506
+ },
507
+ {'event_type': 'Deployer_Trade', # NEW: Testing a trade variant
508
+ 'timestamp': 1729711230,
509
+ 'relative_ts': 40,
510
+ 'wallet_address': 'addrW1', # The creator wallet
511
+ 'token_address': 'tknA',
512
+ 'trade_direction': 1, 'sol_amount': 0.2,
513
+ # --- FIXED: Pass the integer ID directly ---
514
+ 'dex_platform_id': vocab.DEX_TO_ID['Trojan'],
515
+ 'priority_fee': 0.0005,
516
+ 'mev_protection': True,
517
+ 'token_amount_pct_of_holding': 0.1, 'quote_amount_pct_of_holding': 0.0,
518
+ 'slippage': 0.02, 'price_impact': 0.01, 'success': True, 'is_bundle': False, 'total_usd': 30.0
519
+ },
520
+ {'event_type': 'SmartWallet_Trade', # NEW
521
+ 'timestamp': 1729711240,
522
+ 'relative_ts': 50,
523
+ 'wallet_address': 'addrW1', # A known smart wallet
524
+ 'token_address': 'tknA',
525
+ 'trade_direction': 0, 'sol_amount': 1.5,
526
+ # --- FIXED: Pass the integer ID directly ---
527
+ 'dex_platform_id': vocab.DEX_TO_ID['Axiom'],
528
+ 'priority_fee': 0.001,
529
+ 'mev_protection': True,
530
+ 'token_amount_pct_of_holding': 0.2, 'quote_amount_pct_of_holding': 0.1,
531
+ 'slippage': 0.01, 'price_impact': 0.008, 'success': True, 'is_bundle': False, 'total_usd': 225.0
532
+ },
533
+ {'event_type': 'LargeTrade', # NEW
534
+ 'timestamp': 1729711250,
535
+ 'relative_ts': 60,
536
+ 'wallet_address': 'addrW2', # Some other wallet
537
+ 'token_address': 'tknA',
538
+ 'trade_direction': 0, 'sol_amount': 10.0,
539
+ # --- FIXED: Pass the integer ID directly ---
540
+ 'dex_platform_id': vocab.DEX_TO_ID['OXK'],
541
+ 'priority_fee': 0.002,
542
+ 'mev_protection': False,
543
+ 'token_amount_pct_of_holding': 0.8, 'quote_amount_pct_of_holding': 0.5,
544
+ 'slippage': 0.03, 'price_impact': 0.05, 'success': True, 'is_bundle': False, 'total_usd': 1500.0
545
+ },
546
+ {'event_type': 'Chart_Segment', 'timestamp': 1729711260, 'relative_ts': 70, 'opens': [1.0]*OHLC_SEQ_LEN, 'closes': [1.1]*OHLC_SEQ_LEN, 'i': '1s'},
547
+ {'event_type': 'PoolCreated', # NEW
548
+ 'timestamp': 1729711270,
549
+ 'relative_ts': 80,
550
+ 'wallet_address': 'addrW1',
551
+ 'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM'],
552
+ 'quote_token_address': 'tknB',
553
+ 'base_amount': 1000000.0,
554
+ 'quote_amount': 10.0
555
+ },
556
+ {'event_type': 'LiquidityChange', # NEW
557
+ 'timestamp': 1729711280,
558
+ 'relative_ts': 90,
559
+ 'wallet_address': 'addrW2',
560
+ 'quote_token_address': 'tknB',
561
+ 'change_type_id': 0, # 0 for 'add'
562
+ 'quote_amount': 2.0
563
+ },
564
+ {'event_type': 'FeeCollected', # NEW
565
+ 'timestamp': 1729711290,
566
+ 'relative_ts': 100,
567
+ 'wallet_address': 'addrW1', # The recipient (e.g., dev wallet)
568
+ 'sol_amount': 0.1
569
+ },
570
+ {'event_type': 'TokenBurn', # NEW
571
+ 'timestamp': 1729711300,
572
+ 'relative_ts': 110,
573
+ 'wallet_address': 'addrW2', # The burner wallet
574
+ 'amount_pct_of_total_supply': 0.01, # 1% of supply
575
+ 'amount_tokens_burned': 10000000.0
576
+ },
577
+ {'event_type': 'SupplyLock', # NEW
578
+ 'timestamp': 1729711310,
579
+ 'relative_ts': 120,
580
+ 'wallet_address': 'addrW1', # The locker wallet
581
+ 'amount_pct_of_total_supply': 0.10, # 10% of supply
582
+ 'lock_duration': 2592000 # 30 days in seconds
583
+ },
584
+ {'event_type': 'HolderSnapshot', # NEW
585
+ 'timestamp': 1729711320,
586
+ 'relative_ts': 130,
587
+ # This is a pointer to the pre-computed embedding
588
+ # In a real system, this would be the index of the embedding
589
+ 'holders': [ # Raw holder data
590
+ {'wallet': 'addrW1', 'holding_pct': 0.15},
591
+ {'wallet': 'addrW2', 'holding_pct': 0.05},
592
+ # Add more mock holders if needed
593
+ ]
594
+ },
595
+ {'event_type': 'OnChain_Snapshot', # NEW
596
+ 'timestamp': 1729711320,
597
+ 'relative_ts': 130,
598
+ 'total_holders': 500,
599
+ 'smart_traders': 25,
600
+ 'kols': 3,
601
+ 'holder_growth_rate': 0.15,
602
+ 'top_10_holder_pct': 0.22,
603
+ 'sniper_holding_pct': 0.05,
604
+ 'rat_wallets_holding_pct': 0.02,
605
+ 'bundle_holding_pct': 0.01,
606
+ 'current_market_cap': 150000.0,
607
+ 'volume': 50000.0,
608
+ 'buy_count': 120,
609
+ 'sell_count': 80,
610
+ 'total_txns': 200,
611
+ 'global_fees_paid': 1.5
612
+ },
613
+ {'event_type': 'TrendingToken', # NEW
614
+ 'timestamp': 1729711330,
615
+ 'relative_ts': 140,
616
+ 'token_address': 'tknC', # The token that is trending
617
+ 'list_source_id': vocab.TRENDING_LIST_SOURCE_TO_ID['Phantom'],
618
+ 'timeframe_id': vocab.TRENDING_LIST_TIMEFRAME_TO_ID['1h'],
619
+ 'rank': 3
620
+ },
621
+ {'event_type': 'BoostedToken', # NEW
622
+ 'timestamp': 1729711340,
623
+ 'relative_ts': 150,
624
+ 'token_address': 'tknD', # The token that is boosted
625
+ 'total_boost_amount': 5000.0,
626
+ 'rank': 1
627
+ },
628
+ {'event_type': 'XQuoteTweet', # NEW
629
+ 'timestamp': 1729711380,
630
+ 'relative_ts': 190,
631
+ 'wallet_address': 'addrW3', # The quoter
632
+ 'quoter_text_emb_idx': pooler.get_idx('Wow, look at this! $TKA'),
633
+ 'original_author_wallet_address': 'addrW1', # The original author
634
+ 'original_post_text_emb_idx': pooler.get_idx('This is the main tweet about $TKA'),
635
+ 'original_post_media_emb_idx': pooler.get_idx(Image.new('RGB', (100,100), color='cyan'))
636
+ },
637
+ # --- NEW: Add special context tokens ---
638
+ {'event_type': 'MIDDLE', 'timestamp': 1729711500, 'relative_ts': 195},
639
+ {'event_type': 'PumpReply', # NEW
640
+ 'timestamp': 1729711390,
641
+ 'relative_ts': 200,
642
+ 'wallet_address': 'addrW2', # The user who replied
643
+ 'reply_text_emb_idx': pooler.get_idx('to the moon!')
644
+ },
645
+ {'event_type': 'DexBoost_Paid', # NEW
646
+ 'timestamp': 1729711400,
647
+ 'relative_ts': 210,
648
+ 'amount': 5.0, # e.g., 5 Boost
649
+ 'total_amount_on_token': 25.0 # 25 Boost Points
650
+ },
651
+ {'event_type': 'DexProfile_Updated', # NEW
652
+ 'timestamp': 1729711410,
653
+ 'relative_ts': 220,
654
+ 'has_changed_website_flag': True,
655
+ 'has_changed_twitter_flag': False,
656
+ 'has_changed_telegram_flag': True,
657
+ 'has_changed_description_flag': True,
658
+ # Pre-computed text embeddings
659
+ 'website_emb_idx': pooler.get_idx('new-token-website.com'),
660
+ 'twitter_link_emb_idx': pooler.get_idx('old_handle'), # No change, so old link
661
+ 'telegram_link_emb_idx': pooler.get_idx('new_tg_group'),
662
+ 'description_emb_idx': pooler.get_idx('This is the new and improved token description.')
663
+ },
664
+ {'event_type': 'AlphaGroup_Call', # NEW
665
+ 'timestamp': 1729711420,
666
+ 'relative_ts': 230,
667
+ 'group_id': vocab.ALPHA_GROUPS_TO_ID['Potion']
668
+ },
669
+ {'event_type': 'Channel_Call', # NEW
670
+ 'timestamp': 1729711430,
671
+ 'relative_ts': 240,
672
+ 'channel_id': vocab.CALL_CHANNELS_TO_ID['MarcosCalls']
673
+ },
674
+ {'event_type': 'RECENT', 'timestamp': 1729711510, 'relative_ts': 245},
675
+ {'event_type': 'CexListing', # NEW
676
+ 'timestamp': 1729711440,
677
+ 'relative_ts': 250,
678
+ 'exchange_id': vocab.EXCHANGES_TO_ID['mexc']
679
+ },
680
+ {'event_type': 'TikTok_Trending_Hashtag', # NEW
681
+ 'timestamp': 1729711450,
682
+ 'relative_ts': 260,
683
+ 'hashtag_name_emb_idx': pooler.get_idx('CryptoTok'),
684
+ 'rank': 5
685
+ },
686
+ {'event_type': 'XTrending_Hashtag', # NEW
687
+ 'timestamp': 1729711460,
688
+ 'relative_ts': 270,
689
+ 'hashtag_name_emb_idx': pooler.get_idx('SolanaMemes'),
690
+ 'rank': 2
691
+ },
692
+ {'event_type': 'ChainSnapshot', # NEW
693
+ 'timestamp': 1729711470,
694
+ 'relative_ts': 280,
695
+ 'native_token_price_usd': 150.75,
696
+ 'gas_fee': 0.00015 # Example gas fee
697
+ },
698
+ {'event_type': 'Lighthouse_Snapshot', # NEW
699
+ 'timestamp': 1729711480,
700
+ 'relative_ts': 290,
701
+ 'protocol_id': vocab.PROTOCOL_TO_ID['Pump V1'],
702
+ 'timeframe_id': vocab.LIGHTHOUSE_TIMEFRAME_TO_ID['1h'],
703
+ 'total_volume': 1.2e6,
704
+ 'total_transactions': 5000,
705
+ 'total_traders': 1200,
706
+ 'total_tokens_created': 85,
707
+ 'total_migrations': 70
708
+ },
709
+ {'event_type': 'Migrated', # NEW
710
+ 'timestamp': 1729711490,
711
+ 'relative_ts': 300,
712
+ 'protocol_id': vocab.PROTOCOL_TO_ID['Raydium CPMM']
713
+ },
714
+
715
+ ],
716
+ 'wallets': {
717
+ 'addrW1': {'profile': profile1, 'socials': social1, 'holdings': holdings1},
718
+ 'addrW2': {'profile': profile2, 'socials': social2, 'holdings': holdings2},
719
+ # --- NEW: Add wallet 3 data ---
720
+ 'addrW3': {
721
+ 'profile': {**profile2, 'wallet_address': 'addrW3'}, # Reuse profile2 but change address
722
+ 'socials': social3,
723
+ 'holdings': []
724
+ }
725
+ },
726
+ 'tokens': {
727
+ 'tknA': tokenA_data, # Main token
728
+ 'tknB': tokenB_data, # Quote token
729
+ 'tknC': tokenC_data, # Trending token
730
+ 'tknD': tokenD_data # Boosted token
731
+ },
732
+ # --- NEW: The pre-computed embedding pool is generated after collecting all items
733
+ 'embedding_pooler': pooler, # Pass the pooler to generate the tensor later
734
+
735
+ # --- NEW: Expanded graph_links to test all encoders ---
736
+ # --- FIXED: Removed useless logging fields as per user request ---
737
+ 'graph_links': {
738
+ 'TransferLink': {'links': [{'timestamp': 1729711205}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
739
+ 'BundleTradeLink': {'links': [{'timestamp': 1729711215}], 'edges': [('addrW1', 'addrW2')]}, # Keep timestamp
740
+ 'CopiedTradeLink': {'links': [
741
+ {'time_gap_on_buy_sec': 10, 'time_gap_on_sell_sec': 120, 'leader_pnl': 5.0, 'follower_pnl': 4.0, 'follower_buy_total': 100, 'follower_sell_total': 120}
742
+ ], 'edges': [('addrW1', 'addrW2')]},
743
+ 'CoordinatedActivityLink': {'links': [
744
+ {'time_gap_on_first_sec': 5, 'time_gap_on_second_sec': 8}
745
+ ], 'edges': [('addrW1', 'addrW2')]},
746
+ 'MintedLink': {'links': [
747
+ {'timestamp': 1729711200, 'buy_amount': 1e9}
748
+ ], 'edges': [('addrW1', 'tknA')]},
749
+ 'SnipedLink': {'links': [
750
+ {'rank': 1, 'sniped_amount': 5e8}
751
+ ], 'edges': [('addrW1', 'tknA')]},
752
+ 'LockedSupplyLink': {'links': [
753
+ {'amount': 1e10} # Only amount is needed
754
+ ], 'edges': [('addrW1', 'tknA')]},
755
+ 'BurnedLink': {'links': [
756
+ {'timestamp': 1729711300} # Only timestamp is needed
757
+ ], 'edges': [('addrW2', 'tknA')]},
758
+ 'ProvidedLiquidityLink': {'links': [
759
+ {'timestamp': 1729711250} # Only timestamp is needed
760
+ ], 'edges': [('addrW1', 'tknA')]},
761
+ 'WhaleOfLink': {'links': [
762
+ {} # Just the existence of the link is the feature
763
+ ], 'edges': [('addrW1', 'tknA')]},
764
+ 'TopTraderOfLink': {'links': [
765
+ {'pnl_at_creation': 50000.0} # Only PnL is needed
766
+ ], 'edges': [('addrW2', 'tknA')]}
767
+ },
768
+
769
+ # --- FIXED: Removed chart_segments dictionary ---
770
+ 'labels': torch.randn(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0),
771
+ 'labels_mask': torch.ones(self.num_outputs) if self.num_outputs > 0 else torch.zeros(0)
772
+ }
773
+
774
+ print("Mock raw batch created.")
775
+
776
+ return item
log.log CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9ce8d085fbecbf5108090a954e61db882a7ba0e7fddf4a57223d72e8ebf7713d
3
- size 1378
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df6cd6a1404a931ba4869d7eaf6e6a564e98b0a87f04d8edf8f6189aebfdeab4
3
+ size 20694
models/vocabulary.py CHANGED
@@ -186,3 +186,14 @@ EXCHANGES = [
186
  EXCHANGES_TO_ID = {name: i for i, name in enumerate(EXCHANGES)}
187
  ID_TO_EXCHANGES = {i: name for i, name in enumerate(EXCHANGES)}
188
  NUM_EXCHANGES = len(EXCHANGES)
 
 
 
 
 
 
 
 
 
 
 
 
186
  EXCHANGES_TO_ID = {name: i for i, name in enumerate(EXCHANGES)}
187
  ID_TO_EXCHANGES = {i: name for i, name in enumerate(EXCHANGES)}
188
  NUM_EXCHANGES = len(EXCHANGES)
189
+
190
+ # --- NEW: Return Class Thresholds ---
191
+ # Class 0: 0 - 3x
192
+ # Class 1: 3 - 10x
193
+ # Class 2: 10 - 20x
194
+ # Class 3: 20 - 100x
195
+ # Class 4: 100 - 10,000x
196
+ RETURN_THRESHOLDS = [0, 3, 10, 20, 100, 10000]
197
+
198
+ # Class 5: Manipulated (High return but suspicious metrics)
199
+ MANIPULATED_CLASS_ID = 5
pre_cache.sh CHANGED
@@ -3,6 +3,7 @@
3
 
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
- --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz"
 
7
 
8
  echo "Done!"
 
3
 
4
  echo "Starting dataset caching..."
5
  python3 scripts/cache_dataset.py \
6
+ --ohlc_stats_path "/workspace/apollo/data/ohlc_stats.npz" \
7
+ --max_samples 500
8
 
9
  echo "Done!"
scripts/analyze_distribution.py CHANGED
@@ -2,117 +2,536 @@
2
  import os
3
  import sys
4
  import datetime
5
- from dotenv import load_dotenv
6
  from clickhouse_driver import Client as ClickHouseClient
7
 
8
  # Add parent to path
9
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10
 
11
- load_dotenv()
 
12
 
13
  CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
14
  CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
15
- CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
 
16
  CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
17
  CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
18
 
19
- def analyze():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- client = ClickHouseClient(
22
- host=CLICKHOUSE_HOST,
23
- port=CLICKHOUSE_PORT,
24
- user=CLICKHOUSE_USER,
25
- password=CLICKHOUSE_PASSWORD,
26
- database=CLICKHOUSE_DATABASE
27
- )
28
 
29
- print("--- Database Stats Analysis ---")
 
 
 
 
 
30
 
31
- # 1. Total Mints
32
- total_mints = client.execute("SELECT count() FROM mints")[0][0]
33
- print(f"Total Mints: {total_mints}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- if total_mints == 0:
36
- print("No data found.")
37
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- # 2. Migrated Count (Proxy: launchpad != protocol OR check if in raydium pairs)
40
- # Assuming we can infer success or use token_metrics
41
- # Let's look at ATH Price distribution from token_metrics which is populated by the indexer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Check coverage of token_metrics
44
- total_metrics = client.execute("SELECT count() FROM token_metrics")[0][0]
45
- print(f"Tokens with Metrics: {total_metrics} (Coverage: {total_metrics/total_mints*100:.1f}%)")
46
-
47
- # 3. ATH Price Stats
48
- # We need to know what a '5x' looks like.
49
- # Since we don't have 'opening price' easily indexed for all, let's assume standard pump.fun open price ranges
50
- # or just look at Market Cap distribution if available, or just raw ATH price.
51
- # Pump.fun launch MC is usually ~$4-5k.
52
- # 5x = $25k MC.
53
- # 10x = $50k MC (Migration).
54
 
55
- # Let's check distribution of ath_price_usd * total_supply (Approx ATH Market Cap)
56
- # We need total_supply from tokens table.
 
 
 
57
 
58
- print("\n--- ATH Market Cap Distribution (Approx) ---")
59
- query_mc_buckets = """
60
- SELECT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  case
62
- when mc < 5000 then '1. < $5k (Fail)'
63
- when mc >= 5000 AND mc < 20000 then '2. $5k - $20k (2x-4x)'
64
- when mc >= 20000 AND mc < 60000 then '3. $20k - $60k (4x-12x)'
65
- when mc >= 60000 AND mc < 150000 then '4. $60k - $150k (12x-30x)'
66
- when mc >= 150000 then '5. > $150k (Mooners)'
67
  else 'Unknown'
68
- end as bucket,
69
- count() as cnt
70
- FROM (
71
- SELECT
72
- tm.ath_price_usd * (t.total_supply / pow(10, t.decimals)) as mc
73
- FROM token_metrics tm
74
- JOIN tokens t ON tm.token_address = t.token_address
75
- )
76
- GROUP BY bucket
77
- ORDER BY bucket
78
  """
79
- rows = client.execute(query_mc_buckets)
80
- for r in rows:
81
- print(f"{r[0]}: {r[1]} tokens")
82
 
83
- # 4. Volume Distribution
84
- # Helps define "High Volume Losers" vs "Garbage"
85
- print("\n--- Volume Distribution (Total USD) ---")
86
- # Aggregating all trades is heavy, let's do a sample or use token_metrics if it has volume (it doesn't seem to have volume sum in snippet)
87
- # We'll use a subquery on trades for a subset or just a heavy query if local
88
-
89
- query_vol_buckets = """
90
- SELECT
91
  case
92
- when vol < 100 then '1. < $100 (Dead)'
93
- when vol >= 100 AND vol < 1000 then '2. $100 - $1k (Tiny)'
94
- when vol >= 1000 AND vol < 10000 then '3. $1k - $10k (Noise)'
95
- when vol >= 10000 AND vol < 100000 then '4. $10k - $100k (Active)'
96
- when vol >= 100000 then '5. > $100k (High)'
97
  else 'Unknown'
98
- end as bucket,
99
- count() as cnt
100
- FROM (
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  SELECT
102
- base_address, sum(price_usd * amount_decimal) as vol
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  FROM trades
104
  GROUP BY base_address
105
- )
106
- GROUP BY bucket
107
- ORDER BY bucket
108
  """
109
- # This might be slow on huge datasets.
110
- rows_vol = client.execute(query_vol_buckets)
111
- for r in rows_vol:
112
- print(f"{r[0]}: {r[1]} tokens")
 
 
 
 
 
 
 
 
113
 
114
- except Exception as e:
115
- print(f"Error: {e}")
116
 
117
  if __name__ == "__main__":
118
  analyze()
 
2
  import os
3
  import sys
4
  import datetime
 
5
  from clickhouse_driver import Client as ClickHouseClient
6
 
7
  # Add parent to path
8
  sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9
 
10
+ # removed dotenv
11
+ # load_dotenv()
12
 
13
  CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
14
  CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
15
+ # .env shows empty user/pass, which implies 'default' user and empty password for ClickHouse
16
+ CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER", "default")
17
  CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD", "")
18
  CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
19
 
20
+ def get_client():
21
+ return ClickHouseClient(
22
+ host=CLICKHOUSE_HOST,
23
+ port=CLICKHOUSE_PORT,
24
+ user=CLICKHOUSE_USER,
25
+ password=CLICKHOUSE_PASSWORD,
26
+ database=CLICKHOUSE_DATABASE
27
+ )
28
+
29
+ def print_distribution_stats(client, metric_name, subquery, bucket_case_sql):
30
+ print(f"\n -> {metric_name}")
31
+
32
+ # 1. Print Basic Stats (Mean, Quantiles)
33
+ stats_query = f"""
34
+ SELECT
35
+ avg(val),
36
+ quantiles(0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)(val),
37
+ min(val),
38
+ max(val),
39
+ count()
40
+ FROM (
41
+ {subquery}
42
+ )
43
+ """
44
  try:
45
+ stats = client.execute(stats_query)[0]
46
+ avg_val = stats[0]
47
+ qs = stats[1]
48
+ min_val = stats[2]
49
+ max_val = stats[3]
50
+ count_val = stats[4]
 
51
 
52
+ if count_val == 0:
53
+ print(" No data for this segment.")
54
+ return
55
+
56
+ print(f" Mean: {avg_val:.4f} | Min: {min_val:.4f} | Max: {max_val:.4f}")
57
+ print(f" Q: p10={qs[0]:.2f} p50={qs[2]:.2f} p90={qs[4]:.2f} p99={qs[6]:.2f}")
58
 
59
+ except Exception as e:
60
+ print(f" Error calculating stats: {e}")
61
+ return
62
+
63
+ # 2. Print Buckets
64
+ query = f"""
65
+ SELECT
66
+ {bucket_case_sql} as bucket,
67
+ count() as cnt
68
+ FROM (
69
+ {subquery}
70
+ )
71
+ GROUP BY bucket
72
+ ORDER BY bucket
73
+ """
74
+ try:
75
+ rows = client.execute(query)
76
+ # total_count used for pct is the count_val from stats
77
+ print(" Buckets:")
78
+ for r in rows:
79
+ pct = (r[1] / count_val * 100) if count_val > 0 else 0
80
+ print(f" {r[0]}: {r[1]} ({pct:.1f}%)")
81
+ except Exception as e:
82
+ print(f" Error calculating buckets: {e}")
83
+
84
+ def get_filtered_metric_query(inner_query, cohort_sql):
85
+ """
86
+ Wraps the inner metric query to only include tokens in the cohort.
87
+ Assumes inner_query returns 'base_address' (or aliased) and 'val'.
88
+ If the inner query returns 'token_address', it should be handled.
89
+ Most of our queries return 'base_address' (from trades) or 'token_address' (from token_metrics).
90
+ We will normalize to use 'base_address' via subquery alias if needed, but simplest is
91
+ to filter on the outer Select.
92
+ """
93
+ # We need to know if the inner query produces 'base_address' or 'token_address'
94
+ # Currently our queries produce 'base_address' mostly, except token_metrics ones.
95
+ # Let's standardize inner queries in the main loop to alias the key column to 'join_key'
96
+
97
+ return f"""
98
+ SELECT * FROM (
99
+ {inner_query}
100
+ ) WHERE join_key IN ({cohort_sql})
101
+ """
102
+
103
+ import numpy as np
104
+ from models.vocabulary import RETURN_THRESHOLDS, MANIPULATED_CLASS_ID
105
+
106
def get_return_class_map(client, return_thresholds=None, manipulated_class_id=None):
    """
    Build a mapping of token_address -> return-class id.

    Classifies every token by its ATH return multiple
    (ret = ath_price_usd / 0.000004) into the buckets defined by
    ``return_thresholds``, then applies Dynamic Outlier Detection:

      - For each "successful" class (1..N-1) compute the median fees,
        volume and holder count of its members.
      - Any member whose fees OR volume OR holders fall below 50% of its
        class median is downgraded to the "manipulated" class.

    Tokens returning more than 10,000x are excluded by the SQL HAVING
    clause and therefore absent from the returned map.

    Args:
        client: ClickHouse client exposing ``execute(sql) -> list[tuple]``.
        return_thresholds: Optional ascending list of bucket edges
            (defaults to module-level RETURN_THRESHOLDS). A list of
            length K defines K-1 classes; class i covers
            [edges[i], edges[i+1]).
        manipulated_class_id: Optional id assigned to downgraded outliers
            (defaults to module-level MANIPULATED_CLASS_ID).

    Returns:
        (final_map, thresholds):
            final_map: dict of token_address -> class id (int).
            thresholds: dict of class_id -> {'fees','vol','holders'}
                outlier cutoffs (50% of the class median).
    """
    if return_thresholds is None:
        return_thresholds = RETURN_THRESHOLDS
    if manipulated_class_id is None:
        manipulated_class_id = MANIPULATED_CLASS_ID

    print("  -> Fetching metrics for classification...")
    # SQL OPTIMIZATION:
    # 1. Use token_metrics for Volume/Holders (pre-computed snapshots).
    # 2. Pre-aggregate trades for Fees in a subquery to avoid a massive
    #    JOIN explosion.
    query = """
    SELECT
        tm.token_address,
        (argMax(tm.ath_price_usd, tm.updated_at) / 0.000004) as ret,
        any(tr.fees) as fees,
        argMax(tm.total_volume_usd, tm.updated_at) as vol,
        argMax(tm.unique_holders, tm.updated_at) as holders
    FROM token_metrics tm
    LEFT JOIN (
        SELECT
            base_address,
            sum(priority_fee + coin_creator_fee) as fees
        FROM trades
        GROUP BY base_address
    ) tr ON tm.token_address = tr.base_address
    GROUP BY tm.token_address
    HAVING ret <= 10000
    """
    rows = client.execute(query)

    # 1. Initial classification; also collect per-class stats for the
    #    median computation below.
    temp_map = {}  # token -> {'id', 'fees', 'vol', 'holders'}
    class_stats = {i: {'fees': [], 'vol': [], 'holders': []}
                   for i in range(len(return_thresholds) - 1)}

    print(f"  -> Initial classification of {len(rows)} tokens...")
    for token_addr, ret_val, fees, vol, holders in rows:
        fees = fees or 0.0
        vol = vol or 0.0
        holders = holders or 0

        class_id = -1
        for i in range(len(return_thresholds) - 1):
            if return_thresholds[i] <= ret_val < return_thresholds[i + 1]:
                class_id = i
                break

        if class_id != -1:
            temp_map[token_addr] = {'id': class_id, 'fees': fees, 'vol': vol, 'holders': holders}
            class_stats[class_id]['fees'].append(fees)
            class_stats[class_id]['vol'].append(vol)
            class_stats[class_id]['holders'].append(holders)

    # 2. Per-class medians -> outlier thresholds.
    # NOTE(review): earlier comments claimed "< 10% of median" but the code
    # has always used a 0.5 multiplier; documentation now reflects the
    # actual 50% behavior.
    thresholds = {}
    print("  -> Calculating Class Medians & Thresholds (< 50% of Median)...")
    # Class 0 (Garbage) is never checked/filtered.
    for i in range(1, len(return_thresholds) - 1):
        if class_stats[i]['fees']:
            med_fees = np.median(class_stats[i]['fees'])
            med_vol = np.median(class_stats[i]['vol'])
            med_holders = np.median(class_stats[i]['holders'])

            thresholds[i] = {
                'fees': med_fees * 0.5,
                'vol': med_vol * 0.5,
                'holders': med_holders * 0.5
            }
            print(f"     [Class {i}] Median Fees: {med_fees:.4f} (Thresh: {thresholds[i]['fees']:.4f}) | Median Vol: ${med_vol:.0f} (Thresh: ${thresholds[i]['vol']:.0f}) | Median Holders: {med_holders:.0f} (Thresh: {thresholds[i]['holders']:.0f})")
        else:
            thresholds[i] = {'fees': 0, 'vol': 0, 'holders': 0}

    # 3. Reclassification: downgrade suspicious members of classes > 0.
    print("  -> Detecting Manipulated Outliers...")
    final_map = {}
    manipulated_count = 0

    for token, data in temp_map.items():
        cid = data['id']
        # Only "successful" classes (id > 0) are candidates for downgrade.
        if cid > 0 and cid in thresholds:
            t = thresholds[cid]
            # If ANY metric is suspiciously low relative to the class median.
            is_manipulated = (data['fees'] < t['fees']) or (data['vol'] < t['vol']) or (data['holders'] < t['holders'])

            if is_manipulated:
                final_map[token] = manipulated_class_id
                manipulated_count += 1
            else:
                final_map[token] = cid
        else:
            final_map[token] = cid

    print(f"  -> Reclassification Complete. Identified {manipulated_count} manipulated tokens.")
    return final_map, thresholds
211
+
212
def analyze():
    """Print bucketed distribution statistics for each return-class segment.

    Pipeline:
      1. Classify every token via ``get_return_class_map`` (also yields the
         per-class outlier thresholds).
      2. For each segment, express its membership as a SQL cohort subquery
         (instead of an enormous ``IN (...)`` literal list) and print
         bucketed distributions for fees, volume, holders, sniper/bundle/
         dev/insider supply shares and time-to-ATH.
    """
    client = get_client()

    print("=== SEGMENTED DISTRIBUTION ANALYSIS ===")

    # 1. Classified map AND the thresholds used for outlier downgrades.
    class_map, thresholds = get_return_class_map(client)

    # 2. Invert the map (class_id -> tokens) for per-segment counts.
    segments_tokens = {}
    for token, cls in class_map.items():
        segments_tokens.setdefault(cls, []).append(token)

    labels = {
        0: "0. Garbage (< 3x)",
        1: "1. Profitable (3x-10x)",
        2: "2. Good (10x-20x)",
        3: "3. Hyped (20x-100x)",
        4: "4. PVE (100x-10kx)",
        MANIPULATED_CLASS_ID: "5. MANIPULATED (Fake Metrics)"
    }

    # Base cohort: compute (ret, fees, vol, holders) once for EVERY token;
    # each segment is then a WHERE clause over this source.
    # fees come from trades (sum); ret/vol/holders from token_metrics (argMax).
    base_cohort_source = """
        SELECT
            tm.token_address as join_key,
            (argMax(tm.ath_price_usd, tm.updated_at) / 0.000004) as ret,
            any(tr.fees) as fees,
            argMax(tm.total_volume_usd, tm.updated_at) as vol,
            argMax(tm.unique_holders, tm.updated_at) as holders
        FROM token_metrics tm
        LEFT JOIN (
            SELECT base_address, sum(priority_fee + coin_creator_fee) as fees
            FROM trades
            GROUP BY base_address
        ) tr ON tm.token_address = tr.base_address
        GROUP BY tm.token_address
    """

    # Wrap a per-metric inner query so only cohort members survive.
    # Hoisted out of the segment loop: it does not depend on the segment.
    def make_query(inner, cohort_subquery):
        return f"""
        SELECT * FROM (
            {inner}
        ) WHERE join_key IN (
            {cohort_subquery}
        )
        """

    # Shared bucket definition for all "% of supply" metrics.
    pct_buckets = """
        case
            when val < 1 then '1. < 1%'
            when val >= 1 AND val < 5 then '2. 1% - 5%'
            when val >= 5 AND val < 10 then '3. 5% - 10%'
            when val >= 10 AND val < 20 then '4. 10% - 20%'
            when val >= 20 AND val < 50 then '5. 20% - 50%'
            when val >= 50 then '6. > 50%'
            else 'Unknown'
        end
    """

    for cid in sorted(labels.keys()):
        label = labels[cid]
        tokens = segments_tokens.get(cid, [])
        count = len(tokens)

        print(f"\n\n==================================================")
        print(f"SEGMENT: {label}")
        print(f"==================================================")
        print(f"Tokens in segment: {count}")

        if count == 0:
            continue

        # Build the SQL condition matching this class id.
        condition = "1=0"  # default: match nothing

        if cid == 0:
            # Garbage: only tokens below the first "profitable" edge can be
            # class 0 (downgraded tokens go to the manipulated class instead).
            # Use RETURN_THRESHOLDS[1] rather than a hard-coded 3 so the SQL
            # stays consistent with the Python classification.
            condition = f"ret < {RETURN_THRESHOLDS[1]}"

        elif cid == MANIPULATED_CLASS_ID:
            # Manipulated: union of (class-k range AND outlier) for k = 1..4.
            sub_conds = []
            for k in range(1, 5):
                if k in thresholds:
                    t = thresholds[k]
                    lower = RETURN_THRESHOLDS[k]
                    upper = RETURN_THRESHOLDS[k + 1]
                    sub_conds.append(f"(ret >= {lower} AND ret < {upper} AND (fees < {t['fees']} OR vol < {t['vol']} OR holders < {t['holders']}))")
            if sub_conds:
                condition = " OR ".join(sub_conds)

        else:
            # Normal classes 1-4: in range AND NOT an outlier.
            if cid in thresholds:
                t = thresholds[cid]
                lower = RETURN_THRESHOLDS[cid]
                upper = RETURN_THRESHOLDS[cid + 1]
                condition = f"(ret >= {lower} AND ret < {upper} AND fees >= {t['fees']} AND vol >= {t['vol']} AND holders >= {t['holders']})"

        # Cohort: token keys satisfying the segment condition.
        cohort_sql = f"""
        SELECT join_key FROM (
            {base_cohort_source}
        ) WHERE {condition}
        """

        # --- Metric distributions ---

        # 1. Fees (SOL)
        fees_inner = """
            SELECT base_address as join_key, sum(priority_fee + coin_creator_fee) as val
            FROM trades
            GROUP BY base_address
        """
        fees_buckets = """
            case
                when val < 0.001 then '1. < 0.001 SOL'
                when val >= 0.001 AND val < 0.01 then '2. 0.001 - 0.01'
                when val >= 0.01 AND val < 0.1 then '3. 0.01 - 0.1'
                when val >= 0.1 AND val < 1 then '4. 0.1 - 1'
                when val >= 1 then '5. > 1 SOL'
                else 'Unknown'
            end
        """
        print_distribution_stats(client, "Total Fees (SOL)", make_query(fees_inner, cohort_sql), fees_buckets)

        # 2. Volume (USD)
        vol_inner = """
            SELECT base_address as join_key, sum(total_usd) as val
            FROM trades
            GROUP BY base_address
        """
        vol_buckets = """
            case
                when val < 1000 then '1. < $1k'
                when val >= 1000 AND val < 10000 then '2. $1k - $10k'
                when val >= 10000 AND val < 100000 then '3. $10k - $100k'
                when val >= 100000 AND val < 1000000 then '4. $100k - $1M'
                when val >= 1000000 then '5. > $1M'
                else 'Unknown'
            end
        """
        print_distribution_stats(client, "Total Volume (USD)", make_query(vol_inner, cohort_sql), vol_buckets)

        # 3. Unique Holders (latest token_metrics snapshot)
        holders_inner = """
            SELECT token_address as join_key, argMax(unique_holders, updated_at) as val
            FROM token_metrics
            GROUP BY token_address
        """
        holders_buckets = """
            case
                when val < 10 then '1. < 10'
                when val >= 10 AND val < 50 then '2. 10 - 50'
                when val >= 50 AND val < 100 then '3. 50 - 100'
                when val >= 100 AND val < 500 then '4. 100 - 500'
                when val >= 500 then '5. > 500'
                else 'Unknown'
            end
        """
        print_distribution_stats(client, "Unique Holders", make_query(holders_inner, cohort_sql), holders_buckets)

        # 4. Snipers % Supply — share bought by the first 70 unique buyers,
        #    ranked by (slot, tx index) of their first buy.
        snipers_inner = """
            SELECT
                m.base_address as join_key,
                (m.val / t.total_supply * 100) as val
            FROM (
                SELECT
                    base_address,
                    sumIf(base_amount, buyer_rank <= 70) as val
                FROM (
                    SELECT
                        base_address,
                        base_amount,
                        dense_rank() OVER (PARTITION BY base_address ORDER BY min_slot, min_idx) as buyer_rank
                    FROM (
                        SELECT
                            base_address,
                            maker,
                            min(slot) as min_slot,
                            min(transaction_index) as min_idx,
                            sum(base_amount) as base_amount
                        FROM trades
                        WHERE trade_type = 0
                        GROUP BY base_address, maker
                    )
                )
                GROUP BY base_address
            ) m
            JOIN (
                SELECT token_address, argMax(total_supply, updated_at) as total_supply
                FROM tokens
                GROUP BY token_address
            ) t ON m.base_address = t.token_address
            WHERE t.total_supply > 0
        """
        print_distribution_stats(client, "Snipers % Supply (Top 70)", make_query(snipers_inner, cohort_sql), pct_buckets)

        # 5. Bundled % Supply — buys landing in the token's very first slot.
        bundled_inner = """
            SELECT
                m.base_address as join_key,
                (m.val / t.total_supply * 100) as val
            FROM (
                SELECT
                    t.base_address,
                    sum(t.base_amount) as val
                FROM trades t
                JOIN (
                    SELECT base_address, min(slot) as min_slot
                    FROM trades
                    GROUP BY base_address
                ) m ON t.base_address = m.base_address AND t.slot = m.min_slot
                WHERE t.trade_type = 0
                GROUP BY t.base_address
            ) m
            JOIN (
                SELECT token_address, argMax(total_supply, updated_at) as total_supply
                FROM tokens
                GROUP BY token_address
            ) t ON m.base_address = t.token_address
            WHERE t.total_supply > 0
        """
        print_distribution_stats(client, "Bundled % Supply", make_query(bundled_inner, cohort_sql), pct_buckets)

        # 6. Dev Holding % Supply — creator wallet's current balance.
        dev_inner = """
            SELECT
                t.token_address as join_key,
                (wh.current_balance / (t.total_supply / pow(10, t.decimals)) * 100) as val
            FROM (
                SELECT token_address, argMax(creator_address, updated_at) as creator_address, argMax(total_supply, updated_at) as total_supply, argMax(decimals, updated_at) as decimals
                FROM tokens
                GROUP BY token_address
            ) t
            JOIN (
                SELECT mint_address, wallet_address, argMax(current_balance, updated_at) as current_balance
                FROM wallet_holdings
                GROUP BY mint_address, wallet_address
            ) wh ON t.token_address = wh.mint_address AND t.creator_address = wh.wallet_address
            WHERE t.total_supply > 0
        """
        print_distribution_stats(client, "Dev Holding % Supply", make_query(dev_inner, cohort_sql), pct_buckets)

        # 7. Insiders % Supply — wallets that never bought but received
        #    transfers (buys = 0 AND transfers_in > 0).
        insiders_inner = """
            SELECT
                wh.mint_address as join_key,
                (sum(wh.current_balance) / (t.total_supply / pow(10, t.decimals)) * 100) as val
            FROM (
                SELECT mint_address, wallet_address, argMax(current_balance, updated_at) as current_balance
                FROM wallet_holdings
                GROUP BY mint_address, wallet_address
            ) wh
            JOIN (
                SELECT wallet_address,
                       argMax(total_buys_count, updated_at) as buys,
                       argMax(transfers_in_count, updated_at) as transfers,
                       argMax(spl_transfers_in_count, updated_at) as spl_transfers
                FROM wallet_profile_metrics
                GROUP BY wallet_address
            ) wpm ON wh.wallet_address = wpm.wallet_address
            JOIN (
                SELECT token_address, argMax(total_supply, updated_at) as total_supply, argMax(decimals, updated_at) as decimals
                FROM tokens
                GROUP BY token_address
            ) t ON wh.mint_address = t.token_address
            WHERE wpm.buys = 0 AND (wpm.transfers > 0 OR wpm.spl_transfers > 0) AND t.total_supply > 0
            GROUP BY wh.mint_address, t.total_supply, t.decimals
        """
        print_distribution_stats(client, "Insiders % Supply", make_query(insiders_inner, cohort_sql), pct_buckets)

        # 8. Time to ATH (seconds) — timestamp of the highest-priced trade
        #    minus the first trade's timestamp.
        time_ath_inner = """
            SELECT
                base_address as join_key,
                (argMax(timestamp, price_usd) - min(timestamp)) as val
            FROM trades
            GROUP BY base_address
        """
        time_ath_buckets = """
            case
                when val < 5 then '1. < 5s'
                when val >= 5 AND val < 30 then '2. 5s - 30s'
                when val >= 30 AND val < 60 then '3. 30s - 1m'
                when val >= 60 AND val < 300 then '4. 1m - 5m'
                when val >= 300 AND val < 3600 then '5. 5m - 1h'
                when val >= 3600 then '6. > 1h'
                else 'Unknown'
            end
        """
        print_distribution_stats(client, "Time to ATH (Seconds)", make_query(time_ath_inner, cohort_sql), time_ath_buckets)
534
 
 
 
535
 
536
  if __name__ == "__main__":
537
  analyze()
scripts/analyze_hyperparams.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import argparse
6
+ from tqdm import tqdm
7
+ from datetime import datetime, timezone
8
+ from collections import defaultdict
9
+
10
+
11
+ # Add project root to path
12
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
13
+ from data.data_loader import OracleDataset, DataFetcher
14
+
15
+ import os
16
+ import sys
17
+ import numpy as np
18
+ import argparse
19
+ from tqdm import tqdm
20
+ from datetime import datetime, timezone
21
+ import collections
22
+
23
+ # Add project root to path
24
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
25
+ from data.data_loader import DataFetcher
26
+
27
+ import os
28
+ import sys
29
+ import numpy as np
30
+ import argparse
31
+ from tqdm import tqdm
32
+ from datetime import datetime, timezone
33
+ import collections
34
+ from dotenv import load_dotenv
35
+ from clickhouse_driver import Client as ClickHouseClient
36
+ from neo4j import GraphDatabase
37
+
38
+ # Add project root to path
39
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
40
+ from data.data_loader import DataFetcher
41
+
42
def parse_args(argv=None):
    """Parse CLI options for the hyperparameter calibration analysis.

    Args:
        argv: Optional list of argument strings; defaults to
            ``sys.argv[1:]``. Passing a list makes the function testable
            without touching the real command line.

    Returns:
        argparse.Namespace with ``max_samples`` (int, default 5000) and
        ``token_address`` (str or None).
    """
    parser = argparse.ArgumentParser(description="Analyze dataset to tune hyperparameters (Horizons, Seq Len)")
    parser.add_argument("--max_samples", type=int, default=5000, help="Max samples to analyze")
    parser.add_argument("--token_address", type=str, default=None, help="Specific token address to analyze")
    return parser.parse_args(argv)
47
+
48
def main():
    """Run the hyperparameter calibration analysis.

    Connects to ClickHouse/Neo4j, samples tokens (or one specific token),
    pulls each token's full trade history and prints percentile stats for:
      - token lifespan (mint -> last trade),
      - time to ATH (mint -> highest price print),
      - trade counts in the first 5/10/30/60 minutes (seq-len calibration),
    both for all tokens and for a "high activity" subset.
    """
    load_dotenv()
    args = parse_args()

    print("--- Hyperparameter Calibration Analysis (SQL) ---")

    # DB connections (env-driven with local defaults).
    ch_host = os.getenv("CLICKHOUSE_HOST", "localhost")
    ch_port = int(os.getenv("CLICKHOUSE_NATIVE_PORT", 9000))
    neo_uri = os.getenv("NEO4J_URI", "bolt://localhost:7687")
    neo_user = os.getenv("NEO4J_USER", "neo4j")
    neo_pass = os.getenv("NEO4J_PASSWORD", "password")

    print(f"Connecting to ClickHouse at {ch_host}:{ch_port}...")
    clickhouse_client = ClickHouseClient(host=ch_host, port=ch_port)

    print(f"Connecting to Neo4j at {neo_uri}...")
    neo4j_driver = GraphDatabase.driver(neo_uri, auth=(neo_user, neo_pass))

    # 1. Initialize DataFetcher
    fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
    print("DataFetcher initialized.")

    # 2. Fetch sample mints
    if args.token_address:
        print(f"Analyzing specific token: {args.token_address}")
        # Try to find the mint timestamp first.
        query = f"SELECT mint_address, timestamp FROM mints WHERE mint_address = '{args.token_address}'"
        mints = fetcher.db_client.execute(query)
        if not mints:
            print("Token not found in mints table. Trying to use first trade timestamp...")
            # Fallback if not in the mints table.
            q2 = f"SELECT base_address, min(timestamp) FROM trades WHERE base_address = '{args.token_address}' GROUP BY base_address"
            mints = fetcher.db_client.execute(q2)

        if not mints:
            print("Token not found in trades either (or no trades). Exiting.")
            return
    else:
        print(f"Fetching {args.max_samples} sample tokens...")
        query = f"""
        SELECT mint_address, timestamp FROM mints
        ORDER BY rand()
        LIMIT {args.max_samples}
        """
        mints = fetcher.db_client.execute(query)
        print(f"Fetched {len(mints)} tokens.")

    windows_to_test = [5, 10, 30, 60]  # minutes

    # BUGFIX: the original kept parallel lists (lifespans / time_to_ath /
    # counts) whose lengths could diverge — time_to_ath was appended only
    # when a positive price existed — which made subset indexing misaligned
    # (a risk the original comments acknowledged). One record per token
    # keeps every metric aligned.
    records = []  # {'lifespan_min', 'ath_min' (None if no priced trade), 'n_trades', 'counts': {w: int}}

    print(f"Analyzing trades for {len(mints)} tokens...")

    for mint_addr, mint_ts in tqdm(mints):
        try:
            if isinstance(mint_ts, datetime) and mint_ts.tzinfo is None:
                mint_ts = mint_ts.replace(tzinfo=timezone.utc)
            t0 = mint_ts.timestamp()

            # Fetch ALL trades for this token; only timestamp and price are
            # needed, so enrichment limits are disabled (0, 0, 0).
            now_ts = datetime.now(timezone.utc)
            trades, _, _ = fetcher.fetch_trades_for_token(mint_addr, now_ts, 0, 0, 0, full_history=True)

            if not trades:
                continue

            # Trades are usually sorted already, but make sure.
            trades.sort(key=lambda x: x['timestamp'])

            lifespan_min = (trades[-1]['timestamp'].timestamp() - t0) / 60.0

            # Time to ATH over trades with a positive (non-garbage) price.
            max_price = -1.0
            ath_ts = 0.0
            valid_trades = []
            for tr in trades:
                p = float(tr.get('price_usd', 0.0))
                if p > 0:
                    valid_trades.append(tr)
                    if p > max_price:
                        max_price = p
                        ath_ts = tr['timestamp'].timestamp()

            ath_min = ((ath_ts - t0) / 60.0) if max_price > 0 else None

            # Trade counts within each early window.
            counts_in_window = {w: 0 for w in windows_to_test}
            for tr in valid_trades:
                elapsed_min = (tr['timestamp'].timestamp() - t0) / 60.0
                for w in windows_to_test:
                    if elapsed_min <= w:
                        counts_in_window[w] += 1

            records.append({
                'lifespan_min': lifespan_min,
                'ath_min': ath_min,
                'n_trades': len(valid_trades),
                'counts': counts_in_window,
            })

        except Exception as e:
            print(f"Error processing {mint_addr}: {e}")
            import traceback
            traceback.print_exc()

    def print_stats(name, data):
        # Percentile summary for a list of numbers.
        if not data:
            print(f"{name}: No Data")
            return
        arr = np.array(data)
        p25, p50, p75, p90, p95, p99 = np.percentile(arr, [25, 50, 75, 90, 95, 99])
        print(f"[{name}]")
        print(f"  Mean: {np.mean(arr):.2f} | Median: {p50:.2f} | Max: {np.max(arr):.2f}")
        print(f"  25%: {p25:.2f} | 75%: {p75:.2f} | 90%: {p90:.2f} | 95%: {p95:.2f} | 99%: {p99:.2f}")

    print("\n" + "="*40)
    print("RESULTS (ALL TOKENS)")
    print("="*40)

    lifespans_min = [r['lifespan_min'] for r in records]
    time_to_ath_min = [r['ath_min'] for r in records if r['ath_min'] is not None]
    full_history_counts = [r['n_trades'] for r in records]

    print_stats("Token Lifespan (Minutes)", lifespans_min)
    print("\n")
    print_stats("Time to ATH (Minutes)", time_to_ath_min)

    print("\n" + "-"*20)
    print("SEQUENCE LENGTHS (Trades Only)")
    print("-"*20)

    print_stats("Full History Length", full_history_counts)

    for w in windows_to_test:
        print("\n")
        print_stats(f"Trades in First {w} Minutes", [r['counts'][w] for r in records])

    # --- High Activity Subset (correctly aligned via per-token records) ---
    print("\n" + "="*40)
    print("RESULTS (HIGH ACTIVITY SUBSET)")
    print("Filter: > 50 trades AND > 5 min lifespan")
    print("="*40)

    subset = [r for r in records if r['n_trades'] > 50 and r['lifespan_min'] > 5.0]

    if not subset:
        print("No high activity tokens found.")
    else:
        print(f"Found {len(subset)} high activity tokens out of {len(records)}.")

        print_stats("Subset: Token Lifespan (Minutes)", [r['lifespan_min'] for r in subset])
        print("\n")
        print_stats("Subset: Time to ATH (Minutes)", [r['ath_min'] for r in subset if r['ath_min'] is not None])
        print("\n")
        print_stats("Subset: Full History Length", [r['n_trades'] for r in subset])

        for w in windows_to_test:
            print("\n")
            print_stats(f"Subset: Trades in First {w} Min", [r['counts'][w] for r in subset])

    print("\nRecommendation Logic:")
    print("1. Horizons: Look at 'Time to ATH' p90 (or p90 of Subset).")
    print("2. Max Seq Len: Look at 'Trades in First X Minutes' (X ~= Max Horizon).")
253
+
254
+ if __name__ == "__main__":
255
+ main()
scripts/cache_dataset.py CHANGED
@@ -14,104 +14,130 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
 
15
  from data.data_loader import OracleDataset
16
  from data.data_fetcher import DataFetcher
 
 
17
  from clickhouse_driver import Client as ClickHouseClient
18
  from neo4j import GraphDatabase
19
 
20
- # Load environment variables
21
- load_dotenv()
22
-
23
- # --- Configuration ---
24
- CLICKHOUSE_HOST = os.getenv("CLICKHOUSE_HOST", "localhost")
25
- CLICKHOUSE_PORT = int(os.getenv("CLICKHOUSE_PORT", 9000))
26
- CLICKHOUSE_USER = os.getenv("CLICKHOUSE_USER") or "default"
27
- CLICKHOUSE_PASSWORD = os.getenv("CLICKHOUSE_PASSWORD") or ""
28
- CLICKHOUSE_DATABASE = os.getenv("CLICKHOUSE_DATABASE", "default")
29
-
30
- NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
31
- NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
32
- NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "password")
33
-
34
- CACHE_DIR = os.getenv("CACHE_DIR", "/workspace/apollo/data/cache")
35
-
36
  def main():
37
- parser = argparse.ArgumentParser(description="Pre-cache dataset samples.")
38
- parser.add_argument("--max_samples", type=int, default=-1, help="Number of samples to cache. Set to -1 to process all available.")
39
-
40
- parser.add_argument("--start_date", type=str, default=None, help="Start date for filtering mints (YYYY-MM-DD).")
41
- parser.add_argument("--ohlc_stats_path", type=str, default=None, help="Path to OHLC stats JSON.")
42
- parser.add_argument("--min_trade_usd", type=float, default=0.0, help="Minimum trade USD value.")
 
 
 
 
 
 
 
 
 
43
 
44
  args = parser.parse_args()
45
 
46
- # Handle -1 as unlimited (None)
47
- max_samples = args.max_samples if args.max_samples != -1 else None
48
-
49
- # Create cache directory if it doesn't exist
50
- output_dir = Path(CACHE_DIR)
51
  output_dir.mkdir(parents=True, exist_ok=True)
52
 
53
  start_date_dt = None
54
  if args.start_date:
55
- start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d").replace(tzinfo=datetime.timezone.utc)
56
-
57
- # --- 1. Set up database connections ---
 
 
 
58
  try:
59
- print("INFO: Connecting to ClickHouse...")
60
- clickhouse_client = ClickHouseClient(host=CLICKHOUSE_HOST, port=CLICKHOUSE_PORT, user=CLICKHOUSE_USER, password=CLICKHOUSE_PASSWORD, database=CLICKHOUSE_DATABASE)
61
- print("INFO: Connecting to Neo4j...")
62
- neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
63
- except Exception as e:
64
- print(f"ERROR: Failed to connect to databases: {e}", file=sys.stderr)
65
- sys.exit(1)
66
-
67
- # --- 2. Initialize DataFetcher and OracleDataset ---
68
- data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)
69
 
70
- dataset = OracleDataset(
71
- data_fetcher=data_fetcher,
72
- max_samples=max_samples,
73
- start_date=start_date_dt,
74
-
75
- ohlc_stats_path=args.ohlc_stats_path,
76
- horizons_seconds=[60, 300, 900, 1800, 3600],
77
- quantiles=[0.5],
78
- min_trade_usd=args.min_trade_usd
79
- )
80
-
81
- if len(dataset) == 0:
82
- print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
83
- return
84
-
85
- # --- 3. Iterate and cache each item ---
86
- print(f"INFO: Starting to generate and cache {len(dataset)} samples...")
87
- skipped_count = 0
88
- for i in tqdm(range(len(dataset)), desc="Caching samples"):
89
- try:
90
- item = dataset.__cacheitem__(i)
91
- if item is None:
92
- skipped_count += 1
93
- continue
94
- output_path = output_dir / f"sample_{i}.pt"
95
- torch.save(item, output_path)
96
- except Exception as e:
97
- error_msg = str(e)
98
- # If a FATAL error occurs (e.g. persistent DB auth failure), stop the script immediately.
99
- if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
100
- print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
101
- sys.exit(1)
102
-
103
- print(f"\nERROR: Failed to generate or save sample {i} for mint '{dataset.sampled_mints[i]['mint_address']}'. Error: {e}", file=sys.stderr)
104
- # print trackback
105
- import traceback
106
- traceback.print_exc()
107
- skipped_count += 1
108
- continue
109
 
110
- print(f"\n--- Caching Complete ---\nSuccessfully cached: {len(dataset) - skipped_count} items.\nSkipped: {skipped_count} items.\nCache location: {output_dir.resolve()}")
111
-
112
- # --- 4. Close connections ---
113
- clickhouse_client.disconnect()
114
- neo4j_driver.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  if __name__ == "__main__":
117
  main()
 
14
 
15
  from data.data_loader import OracleDataset
16
  from data.data_fetcher import DataFetcher
17
+ from scripts.analyze_distribution import get_return_class_map
18
+
19
  from clickhouse_driver import Client as ClickHouseClient
20
  from neo4j import GraphDatabase
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
def main():
    """Generate and cache OracleDataset samples to disk.

    Connects to ClickHouse + Neo4j, classifies every token via
    ``get_return_class_map`` (tokens missing from the map — e.g. > 10,000x
    returns or missing metrics — are skipped), serializes each valid sample
    with ``torch.save`` and appends a ``{"file", "class_id"}`` line to
    ``metadata.jsonl``. Connections are always closed on exit.
    """
    # NOTE(review): json is used below but a top-of-file `import json` is
    # not visible in this module's import block — keep a local import so
    # the script cannot NameError at runtime.
    import json

    load_dotenv()

    parser = argparse.ArgumentParser(description="Cache dataset samples for training.")
    parser.add_argument("--output_dir", type=str, default="data/cache", help="Directory to save cached samples")
    parser.add_argument("--max_samples", type=int, default=None, help="Maximum number of samples to generate")
    parser.add_argument("--start_date", type=str, default=None, help="Start date (YYYY-MM-DD) for fetching new mints")
    parser.add_argument("--ohlc_stats_path", type=str, default="data/ohlc_stats.npz")
    parser.add_argument("--min_trade_usd", type=float, default=0.0)

    # DB Args (env-driven defaults)
    parser.add_argument("--clickhouse_host", type=str, default=os.getenv("CLICKHOUSE_HOST", "localhost"))
    parser.add_argument("--clickhouse_port", type=int, default=int(os.getenv("CLICKHOUSE_PORT", 9000)))
    parser.add_argument("--neo4j_uri", type=str, default=os.getenv("NEO4J_URI", "bolt://localhost:7687"))
    parser.add_argument("--neo4j_user", type=str, default=os.getenv("NEO4J_USER", "neo4j"))
    parser.add_argument("--neo4j_password", type=str, default=os.getenv("NEO4J_PASSWORD", "password"))

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    start_date_dt = None
    if args.start_date:
        start_date_dt = datetime.datetime.strptime(args.start_date, "%Y-%m-%d")

    print(f"INFO: Initializing DB Connections...")
    clickhouse_client = ClickHouseClient(host=args.clickhouse_host, port=args.clickhouse_port)
    neo4j_driver = GraphDatabase.driver(args.neo4j_uri, auth=(args.neo4j_user, args.neo4j_password))

    try:
        # --- 2. Initialize DataFetcher and OracleDataset ---
        data_fetcher = DataFetcher(clickhouse_client=clickhouse_client, neo4j_driver=neo4j_driver)

        # Pre-fetch the return-class map. Tokens absent from the map
        # (e.g. > 10,000x return) are INVALID and will be skipped below.
        print("INFO: Fetching Return Classification Map...")
        return_class_map, thresholds = get_return_class_map(clickhouse_client)
        print(f"INFO: Loaded {len(return_class_map)} valid classified tokens.")

        dataset = OracleDataset(
            data_fetcher=data_fetcher,
            max_samples=args.max_samples,
            start_date=start_date_dt,
            ohlc_stats_path=args.ohlc_stats_path,
            horizons_seconds=[60, 180, 300, 600, 1800, 3600, 7200],
            quantiles=[0.5],
            min_trade_usd=args.min_trade_usd
        )

        if len(dataset) == 0:
            print("WARNING: Dataset initialization resulted in 0 samples. Nothing to cache.")
            return

        # --- 3. Iterate and cache each item ---
        print(f"INFO: Starting to generate and cache {len(dataset)} samples...")

        metadata_path = output_dir / "metadata.jsonl"
        print(f"INFO: Writing metadata to {metadata_path}")

        skipped_count = 0
        filtered_count = 0
        cached_count = 0

        # Append mode: interrupted runs can resume without losing metadata.
        with open(metadata_path, 'a') as meta_f:
            for i in tqdm(range(len(dataset)), desc="Caching samples"):
                mint_addr = dataset.sampled_mints[i]['mint_address']

                # 1. Filter: tokens with no class (high return / missing
                #    metrics) are not cached at all.
                if mint_addr not in return_class_map:
                    filtered_count += 1
                    continue

                class_id = return_class_map[mint_addr]

                try:
                    item = dataset.__cacheitem__(i)
                    if item is None:
                        skipped_count += 1
                        continue

                    filename = f"sample_{i}.pt"
                    torch.save(item, output_dir / filename)

                    # One short JSON line per sample keeps IO overhead low.
                    meta_f.write(json.dumps({"file": filename, "class_id": class_id}) + "\n")

                    cached_count += 1

                except Exception as e:
                    error_msg = str(e)
                    # A FATAL error (e.g. persistent DB auth failure) must
                    # stop the whole run immediately.
                    if "FATAL" in error_msg or "AuthenticationRateLimit" in error_msg:
                        print(f"\nCRITICAL: Fatal error encountered processing sample {i}. Stopping execution.\nError: {e}", file=sys.stderr)
                        sys.exit(1)

                    print(f"\nERROR: Failed to generate or save sample {i} for mint '{mint_addr}'. Error: {e}", file=sys.stderr)
                    import traceback
                    traceback.print_exc()
                    skipped_count += 1
                    continue

        print(f"\n--- Caching Complete ---")
        print(f"Successfully cached: {cached_count} items.")
        print(f"Filtered (Invalid/High Return): {filtered_count} items.")
        print(f"Skipped (Errors/Empty): {skipped_count} items.")
        print(f"Cache location: {output_dir.resolve()}")
        print(f"Metadata location: {metadata_path.resolve()}")

    finally:
        # --- 4. Close connections ---
        clickhouse_client.disconnect()
        neo4j_driver.close()
141
 
142
  if __name__ == "__main__":
143
  main()
scripts/debug_db_counts.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2

import os
import sys

import clickhouse_connect
from dotenv import load_dotenv

# Load CLICKHOUSE_* settings from a local .env file, if present.
load_dotenv()


def check_max_trades(limit=5):
    """Print the top ``limit`` tokens by trade count from the ``trades`` table.

    Connection settings come from the CLICKHOUSE_HOST / CLICKHOUSE_HTTP_PORT
    environment variables. This is a non-fatal diagnostic: all failures are
    reported to stderr instead of raising.

    Args:
        limit: Number of top tokens to display (default 5, matching the
            original hard-coded behavior).
    """
    try:
        host = os.getenv("CLICKHOUSE_HOST")
        port = os.getenv("CLICKHOUSE_HTTP_PORT")
        # Fail with a clear message instead of int(None) raising a cryptic TypeError.
        if not host or not port:
            print(
                "Error: CLICKHOUSE_HOST and CLICKHOUSE_HTTP_PORT must be set "
                "(e.g. in .env).",
                file=sys.stderr,
            )
            return

        client = clickhouse_connect.get_client(
            host=host,
            port=int(port),
            secure=False,
        )

        print("Connected to ClickHouse.")

        # 1. Find the tokens with the most trades.
        # Full GROUP BY over the trades table; can be slow on large datasets.
        print("Querying max trade count per token (this might take a moment)...")
        # int() coercion keeps the interpolated LIMIT injection-safe.
        query = f"""
        SELECT base_address, count(*) as c
        FROM trades
        GROUP BY base_address
        ORDER BY c DESC
        LIMIT {int(limit)}
        """
        result = client.query(query)

        print(f"Top {int(limit)} Tokens by Trade Count:")
        for row in result.result_rows:
            print(f"Token: {row[0]}, Count: {row[1]}")

    except Exception as e:
        # Diagnostic script: report the failure to stderr rather than crash.
        print(f"Error: {e}", file=sys.stderr)


if __name__ == "__main__":
    check_max_trades()
t.json CHANGED
@@ -1,47 +1,11 @@
1
- "newPairs": {
2
 
3
- "fees": {
4
- "max": null,
5
- "min": null
6
- },
7
- "txns": {
8
- "max": null,
9
- "min": null
10
- },
11
- "bundle": {
12
- "max": null,
13
- "min": null
14
- },
15
- "volume": {
16
- "max": null,
17
- "min": null
18
- },
19
- "holders": {
20
- "max": null,
21
- "min": null
22
- },
23
- "numBuys": {
24
- "max": null,
25
- "min": null
26
- },
27
- "snipers": {
28
- "max": null,
29
- "min": null
30
- },
31
- "insiders": {
32
- "max": null,
33
- "min": null
34
- },
35
- "numSells": {
36
- "max": null,
37
- "min": null
38
- },
39
- "devHolding": {
40
- "max": null,
41
- "min": null
42
- },
43
- "top10Holders": {
44
- "max": null,
45
- "min": null
46
- },
47
- },
 
 
1
 
2
+ "TotalAggregetedFees"
3
+ "TotalSupplyBoughtByBundledTxns"
4
+ "TotalVolume"
5
+ "TotalUniqueHolders"
6
+ "TotalnumBuys"
7
+ "TotalSupplyBoughtBySnipers (first 70 unique wallets)"
8
+ "TotalSupplyHeldByInsiders"
9
+ "TotalnumSells"
10
+ "TotalDevHoldingSupply"
11
+ "totalSupplyHeldByTop10Holders"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
test.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Hyperparameter Analysis & Recommendations
2
+ Objective
3
+ Determine data-driven values for --max_seq_len and --horizons_seconds to optimize model training.
4
+
5
+ Analysis Findings
6
+ 1. Trade Volume Distribution
7
+ General Population (Bias towards Rugs): 99% of tokens have fewer than 1,300 trades in their entire lifetime.
8
+ High-Activity Tokens (Successful Launches): Verified against token HWVY....
9
+ Total Trades: ~300,000.
10
+ First 60 Minutes: ~3,720 trades.
11
+ Rate: Approx. 60-100 trades/minute during the initial pump.
12
+ 2. Time-to-ATH (All-Time High)
13
+ Median: ~3 seconds (Immediate dump/failure).
14
+ 90th Percentile: ~2.6 minutes.
15
+ 99th Percentile: ~90 minutes.
16
+ Conclusion: A model needs to observe at least the first 90 minutes to capture the "peak" behavior of the most successful 1% of tokens.
17
+ Recommendations
18
+ Max Sequence Length (--max_seq_len)
19
+ Recommendation: 8192
20
+
21
+ Logic:
22
+ High-volume tokens generate ~3,700 trades in the first hour.
23
+ To cover the critical 90-minute window (Time-to-ATH 99th percentile) for a high-volume token: 3700 * 1.5 = 5550 trades.
24
+ Adding buffer for liquidity events and higher-intensity bursts: 8192 (next power of two above 5,550).
25
+ This length is sufficient to capture:
26
+ 2+ hours of data for high-activity tokens.
27
+ The entire lifecycle for >99% of all tokens.
28
+ Prediction Horizons (--horizons_seconds)
29
+ Recommendation: 30, 60, 300, 600, 1800, 3600, 7200 (30s, 1m, 5m, 10m, 30m, 1h, 2h)
30
+
31
+ Logic:
32
+ Short-term (30s - 5m): Crucial for immediate volatility and scalping predictions, especially given the median lifespan is extremely short.
33
+ Medium-term (10m - 30m): Captures the trend development for "standard" rugs (90th percentile < 3 min, but tails extend).
34
+ Long-term (1h - 2h): Essential for the 1% of successful tokens where ATH occurs around 90 minutes.
train.py CHANGED
@@ -22,7 +22,7 @@ except RuntimeError:
22
 
23
  import torch
24
  import torch.nn as nn
25
- from torch.utils.data import DataLoader
26
  from torch.optim import AdamW
27
 
28
  # --- Accelerate & Transformers ---
@@ -248,10 +248,28 @@ def main() -> None:
248
  if len(dataset) == 0:
249
  raise RuntimeError("Dataset is empty.")
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  dataloader = DataLoader(
252
  dataset,
253
  batch_size=batch_size,
254
- shuffle=bool(args.shuffle),
 
255
  num_workers=int(args.num_workers),
256
  pin_memory=bool(args.pin_memory),
257
  collate_fn=functools.partial(filtered_collate, collator)
 
22
 
23
  import torch
24
  import torch.nn as nn
25
+ from torch.utils.data import DataLoader, WeightedRandomSampler
26
  from torch.optim import AdamW
27
 
28
  # --- Accelerate & Transformers ---
 
248
  if len(dataset) == 0:
249
  raise RuntimeError("Dataset is empty.")
250
 
251
+ # --- NEW: Weighted Sampling Logic ---
252
+ sampler = None
253
+ shuffle = bool(args.shuffle)
254
+
255
+ # Check if dataset provides weights (from metadata.jsonl)
256
+ if hasattr(dataset, 'get_weights'):
257
+ weights = dataset.get_weights()
258
+ if weights is not None:
259
+ if shuffle:
260
+ logger.info("INFO: Class weights found. Using WeightedRandomSampler for balanced training.")
261
+ # Note: WeightedRandomSampler requires shuffle=False in DataLoader
262
+ # It draws samples with replacement by default.
263
+ sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
264
+ shuffle = False
265
+ else:
266
+ logger.info("INFO: Weights found but shuffle=False. Ignoring weights (sequential mode).")
267
+
268
  dataloader = DataLoader(
269
  dataset,
270
  batch_size=batch_size,
271
+ shuffle=shuffle,
272
+ sampler=sampler,
273
  num_workers=int(args.num_workers),
274
  pin_memory=bool(args.pin_memory),
275
  collate_fn=functools.partial(filtered_collate, collator)
train.sh CHANGED
@@ -1,4 +1,4 @@
1
- /venv/main/bin/accelerate launch train.py \
2
  --epochs 10 \
3
  --batch_size 1 \
4
  --learning_rate 1e-4 \
@@ -11,8 +11,8 @@
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
- --max_seq_len 4096 \
15
- --horizons_seconds 30 60 120 240 420 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
  --num_workers 4 \
 
1
+ accelerate launch train.py \
2
  --epochs 10 \
3
  --batch_size 1 \
4
  --learning_rate 1e-4 \
 
11
  --tensorboard_dir runs/oracle \
12
  --checkpoint_dir checkpoints \
13
  --mixed_precision bf16 \
14
+ --max_seq_len 8192 \
15
+ --horizons_seconds 60 180 300 600 1800 3600 7200 \
16
  --quantiles 0.1 0.5 0.9 \
17
  --ohlc_stats_path ./data/ohlc_stats.npz \
18
  --num_workers 4 \